Mirror of https://github.com/huggingface/candle.git
Merge branch 'ivarflakstad/seperate-benchmarks-by-feature' into ivarflakstad/metal-fill
.github/workflows/ci_cuda.yaml (vendored)

@@ -8,6 +8,8 @@ jobs:
   start-runner:
     name: Start self-hosted EC2 runner
     runs-on: ubuntu-latest
+    # Don't run on forks, they won't have access to secrets anyway.
+    if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
     env:
       AWS_REGION: us-east-1
       EC2_AMI_ID: ami-03cfed9ea28f4b002
@@ -70,7 +72,7 @@ jobs:
     runs-on: ubuntu-latest
     env:
       AWS_REGION: us-east-1
-    if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
+    if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
     steps:
       - name: Configure AWS credentials
         uses: aws-actions/configure-aws-credentials@v1

@@ -158,6 +158,7 @@ And then head over to
 - [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
 - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
 - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
+- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
 
 If you have an addition to this list, please submit a pull request.
 
@@ -28,6 +28,7 @@ let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap()
 #[rustfmt::skip]
 #[test]
 fn book_hub_2() {
+    {
     // ANCHOR: book_hub_2
     use candle::Device;
     use hf_hub::api::sync::Api;
@@ -45,9 +46,10 @@ let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap()
     assert_eq!(weights.len(), 206);
 }
 
-#[rustfmt::skip]
-#[test]
-fn book_hub_3() {
+// #[rustfmt::skip]
+// #[test]
+// fn book_hub_3() {
+    {
     // ANCHOR: book_hub_3
     use candle::{DType, Device, Tensor};
     use hf_hub::api::sync::Api;
@@ -102,6 +104,7 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
     assert_eq!(view.shape(), &[768, 768]);
     assert_eq!(tp_tensor.dims(), &[192, 768]);
 }
+}
 
 #[rustfmt::skip]
 #[test]

@@ -46,7 +46,7 @@ accelerate = ["dep:libc", "dep:accelerate-src"]
 metal = ["dep:metal", "dep:candle-metal-kernels"]
 
 [[bench]]
-name = "matmul"
+name = "bench_main"
 harness = false
 
 [[bench]]

candle-core/benches/bench_main.rs (new file)

@@ -0,0 +1,4 @@
+mod benchmarks;
+
+use criterion::criterion_main;
+criterion_main!(benchmarks::matmul::benches);

@@ -1,5 +1,6 @@
-use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
+use crate::benchmarks::{bench_name, device, BenchDevice};
+use candle_core::{DType, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
 use std::time::Instant;
 
 fn run(a: &Tensor, b: &Tensor) {
@@ -12,14 +13,14 @@ fn criterion_benchmark(c: &mut Criterion) {
     let n = 2048;
     let k = 2048;
 
-    let device = Device::new_metal(0).unwrap();
+    let device = device().unwrap();
     let dtype = DType::F32;
     let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
     let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
 
     let flops = b * m * n * k;
 
-    let mut group = c.benchmark_group("matmul_metal");
+    let mut group = c.benchmark_group(bench_name("matmul"));
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {
@@ -27,11 +28,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             for _i in 0..iters {
                 run(black_box(&lhs), black_box(&rhs));
             }
-            if let Device::Metal(device) = &device {
-                device.wait_until_completed().unwrap();
-            } else {
-                panic!("Expected metal device");
-            }
+            device.sync().unwrap();
             start.elapsed()
         })
     });
@@ -39,4 +36,3 @@ fn criterion_benchmark(c: &mut Criterion) {
 }
 
 criterion_group!(benches, criterion_benchmark);
-criterion_main!(benches);

candle-core/benches/benchmarks/mod.rs (new file)

@@ -0,0 +1,55 @@
+pub(crate) mod matmul;
+
+use candle_core::{Device, Result};
+
+pub(crate) trait BenchDevice {
+    fn sync(&self) -> Result<()>;
+}
+
+impl BenchDevice for Device {
+    fn sync(&self) -> Result<()> {
+        match self {
+            Device::Cpu => Ok(()),
+            Device::Cuda(device) => {
+                #[cfg(feature = "cuda")]
+                return Ok(device.synchronize()?);
+                #[cfg(not(feature = "cuda"))]
+                panic!("Cuda device without cuda feature enabled: {:?}", device)
+            }
+            Device::Metal(device) => {
+                #[cfg(feature = "metal")]
+                return Ok(device.wait_until_completed()?);
+                #[cfg(not(feature = "metal"))]
+                panic!("Metal device without metal feature enabled: {:?}", device)
+            }
+        }
+    }
+}
+
+pub(crate) fn device() -> Result<Device> {
+    if cfg!(feature = "metal") {
+        Device::new_metal(0)
+    } else if cfg!(feature = "cuda") {
+        Device::new_cuda(0)
+    } else {
+        Ok(Device::Cpu)
+    }
+}
+
+pub(crate) fn bench_name<S: Into<String>>(name: S) -> String {
+    format!("{}_{}", device_variant(), name.into())
+}
+
+const fn device_variant() -> &'static str {
+    if cfg!(feature = "metal") {
+        "metal"
+    } else if cfg!(feature = "cuda") {
+        "cuda"
+    } else if cfg!(feature = "accelerate") {
+        "accelerate"
+    } else if cfg!(feature = "mkl") {
+        "mkl"
+    } else {
+        "cpu"
+    }
+}

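With the benchmarks split out this way, adding another device-aware benchmark is mostly boilerplate: a new module under `benches/benchmarks/`, one extra `pub(crate) mod ...;` line in `mod.rs`, and one more entry in `criterion_main!`. A minimal sketch, assuming a hypothetical `benchmarks/affine.rs` (the module name and the measured op are illustrative, not part of this commit):

```rust
// Hypothetical benches/benchmarks/affine.rs, mirroring the matmul benchmark above.
use crate::benchmarks::{bench_name, device, BenchDevice};
use candle_core::{DType, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

fn run(a: &Tensor) {
    a.affine(12.34, 56.78).unwrap();
}

fn criterion_benchmark(c: &mut Criterion) {
    let device = device().unwrap();
    let tensor = Tensor::zeros((1024, 1024), DType::F32, &device).unwrap();

    let mut group = c.benchmark_group(bench_name("affine"));
    group.throughput(Throughput::Bytes((1024 * 1024 * 4) as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run(black_box(&tensor));
            }
            // Drain queued GPU work before reading the clock, like the matmul bench does.
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

criterion_group!(benches, criterion_benchmark);
```

`bench_main.rs` would then read `criterion_main!(benchmarks::matmul::benches, benchmarks::affine::benches);`, and `cargo bench --features metal` (or `cuda`) picks the backend through `device()` while `bench_name` keeps the reported group names distinct per backend.
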
@@ -354,7 +354,7 @@ impl BackendStorage for MetalStorage {
         let name = match self.dtype {
             DType::F32 => "affine_f32",
             DType::F16 => "affine_f16",
-            dtype => crate::bail!("Affine {dtype:?}"),
+            dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_affine(
             &device.device,
@@ -372,7 +372,7 @@ impl BackendStorage for MetalStorage {
         let name = match self.dtype {
             DType::F32 => "affine_f32_strided",
             DType::F16 => "affine_f16_strided",
-            dtype => crate::bail!("Affine {dtype:?}"),
+            dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_affine_strided(
             &device.device,
@@ -405,7 +405,7 @@ impl BackendStorage for MetalStorage {
         let name = match self.dtype {
             DType::F32 => "powf_f32",
             DType::F16 => "powf_f16",
-            dtype => crate::bail!("Powf {dtype:?}"),
+            dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_powf(
             &device.device,
@@ -422,7 +422,7 @@ impl BackendStorage for MetalStorage {
         let name = match self.dtype {
             DType::F32 => "powf_f32_strided",
             DType::F16 => "powf_f16_strided",
-            dtype => crate::bail!("Powf {dtype:?}"),
+            dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_powf_strided(
             &device.device,
@@ -454,7 +454,7 @@ impl BackendStorage for MetalStorage {
         let name = match self.dtype {
             DType::F32 => "elu_f32",
             DType::F16 => "elu_f16",
-            dtype => crate::bail!("Powf {dtype:?}"),
+            dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_elu(
             &device.device,
@@ -471,7 +471,7 @@ impl BackendStorage for MetalStorage {
         let name = match self.dtype {
             DType::F32 => "elu_f32_strided",
             DType::F16 => "elu_f16_strided",
-            dtype => crate::bail!("Powf {dtype:?}"),
+            dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_elu_strided(
             &device.device,
@@ -533,7 +533,17 @@ impl BackendStorage for MetalStorage {
             (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
             (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
             (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
-            (k, dtype) => crate::bail!("Reduce op for non float {k:?} {dtype:?}"),
+            (ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false),
+            (ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false),
+            (ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false),
+            (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true),
+            (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true),
+            (ReduceOp::Sum, DType::U8) => ("fast_sum_u8_strided", false, false),
+            (ReduceOp::Min, DType::U8) => ("fast_min_u8_strided", true, false),
+            (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
+            (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
+            (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
+            (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
         };
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
@@ -580,11 +590,18 @@ impl BackendStorage for MetalStorage {
         let kernel_name = match (self.dtype, dtype) {
             (DType::U32, DType::F32) => "cast_u32_f32",
             (DType::U32, DType::U8) => "cast_u32_u8",
+            (DType::U32, DType::I64) => "cast_u32_i64",
             (DType::U8, DType::U32) => "cast_u8_u32",
             (DType::U8, DType::F32) => "cast_u8_f32",
+            (DType::U8, DType::I64) => "cast_u8_i64",
             (DType::F32, DType::F16) => "cast_f32_f16",
             (DType::F16, DType::F32) => "cast_f16_f32",
-            (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
+            (DType::I64, DType::F32) => "cast_i64_f32",
+            (DType::F32, DType::BF16) => "cast_f32_bf16",
+            (DType::BF16, DType::F32) => "cast_bf16_f32",
+            (left, right) => {
+                crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
+            }
         };
         candle_metal_kernels::call_cast_contiguous(
             &device.device,
@@ -601,11 +618,18 @@ impl BackendStorage for MetalStorage {
         let kernel_name = match (self.dtype, dtype) {
             (DType::U32, DType::F32) => "cast_u32_f32_strided",
             (DType::U32, DType::U8) => "cast_u32_u8_strided",
+            (DType::U32, DType::I64) => "cast_u32_i64_strided",
             (DType::U8, DType::U32) => "cast_u8_u32_strided",
             (DType::U8, DType::F32) => "cast_u8_f32_strided",
+            (DType::U8, DType::I64) => "cast_u8_i64_strided",
             (DType::F32, DType::F16) => "cast_f32_f16_strided",
             (DType::F16, DType::F32) => "cast_f16_f32_strided",
-            (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
+            (DType::I64, DType::F32) => "cast_i64_f32_strided",
+            (DType::F32, DType::BF16) => "cast_f32_bf16_strided",
+            (DType::BF16, DType::F32) => "cast_bf16_f32_strided",
+            (left, right) => {
+                crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
+            }
         };
         candle_metal_kernels::call_cast_strided(
             &device.device,
@@ -646,9 +670,11 @@ impl BackendStorage for MetalStorage {
             ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
             ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
             ("uerf", DType::F32) => contiguous::erf::FLOAT,
+            ("uabs", DType::F32) => contiguous::abs::FLOAT,
             ("uceil", DType::F32) => contiguous::ceil::FLOAT,
             ("ufloor", DType::F32) => contiguous::floor::FLOAT,
             ("uround", DType::F32) => contiguous::round::FLOAT,
+            ("urecip", DType::F32) => contiguous::recip::FLOAT,
             ("utanh", DType::F32) => contiguous::tanh::FLOAT,
             ("ucos", DType::F16) => contiguous::cos::HALF,
             ("usin", DType::F16) => contiguous::sin::HALF,
@@ -660,11 +686,15 @@ impl BackendStorage for MetalStorage {
             ("ugelu", DType::F16) => contiguous::gelu::HALF,
             ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
             ("uerf", DType::F16) => contiguous::erf::HALF,
+            ("uabs", DType::F16) => contiguous::abs::HALF,
             ("uceil", DType::F16) => contiguous::ceil::HALF,
             ("ufloor", DType::F16) => contiguous::floor::HALF,
             ("uround", DType::F16) => contiguous::round::HALF,
+            ("urecip", DType::F16) => contiguous::recip::HALF,
             ("utanh", DType::F16) => contiguous::tanh::HALF,
-            (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
+            (name, dtype) => {
+                crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
+            }
         };
         candle_metal_kernels::call_unary_contiguous(
             &device.device,
@@ -689,6 +719,7 @@ impl BackendStorage for MetalStorage {
             ("ugelu", DType::F32) => strided::gelu::FLOAT,
             ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
             ("uerf", DType::F32) => strided::erf::FLOAT,
+            ("uabs", DType::F32) => strided::abs::FLOAT,
             ("uceil", DType::F32) => strided::ceil::FLOAT,
             ("ufloor", DType::F32) => strided::floor::FLOAT,
             ("uround", DType::F32) => strided::round::FLOAT,
@@ -702,10 +733,13 @@ impl BackendStorage for MetalStorage {
             ("ugelu", DType::F16) => strided::gelu::HALF,
             ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
             ("uerf", DType::F16) => strided::erf::HALF,
+            ("uabs", DType::F16) => strided::abs::HALF,
             ("uceil", DType::F16) => strided::ceil::HALF,
             ("ufloor", DType::F16) => strided::floor::HALF,
             ("uround", DType::F16) => strided::round::HALF,
-            (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
+            (name, dtype) => {
+                crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
+            }
         };
         candle_metal_kernels::call_unary_strided(
             &device.device,
@@ -758,7 +792,10 @@ impl BackendStorage for MetalStorage {
         let name = match (self.dtype, t.dtype()) {
             (DType::U8, DType::F32) => "where_u8_f32",
             (DType::U8, DType::F16) => "where_u8_f16",
-            (left, right) => crate::bail!("where {left:?} - {right:?} not implemented"),
+            (DType::U8, DType::I64) => "where_u8_i64",
+            (DType::U8, DType::U32) => "where_u8_u32",
+            (DType::U8, DType::U8) => "where_u8_u8",
+            (left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
         };
         candle_metal_kernels::call_where_cond_strided(
             &device.device,
@@ -805,7 +842,7 @@ impl BackendStorage for MetalStorage {
         let command_buffer = self.device.command_buffer()?;
         let name = match self.dtype {
             DType::F32 => "im2col1d_f32",
-            dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
+            dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_im2col1d_strided(
             &self.device.device,
@@ -858,7 +895,7 @@ impl BackendStorage for MetalStorage {
         _kernel_l: &Layout,
         _params: &ParamsConvTranspose1D,
     ) -> Result<Self> {
-        crate::bail!("conv_transpose1d metal")
+        crate::bail!("Metal conv_transpose1d not implemented")
     }
 
     fn conv2d(
@@ -889,7 +926,7 @@ impl BackendStorage for MetalStorage {
         let command_buffer = self.device.command_buffer()?;
         let name = match self.dtype {
             DType::F32 => "im2col_f32",
-            dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
+            dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_im2col_strided(
             &self.device.device,
@@ -945,19 +982,19 @@ impl BackendStorage for MetalStorage {
         _kernel_l: &Layout,
         _params: &ParamsConvTranspose2D,
     ) -> Result<Self> {
-        crate::bail!("conv_tranpose2d metal")
+        crate::bail!("Metal conv_tranpose2d not implemented")
     }
 
     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
-        crate::bail!("avg_pool2d metal")
+        crate::bail!("Metal avg_pool2d not implemented")
     }
 
     fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
-        crate::bail!("max_pool2d metal")
+        crate::bail!("Metal max_pool2d not implemented")
     }
 
     fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
-        crate::bail!("upsample_nearest1d metal")
+        crate::bail!("Metal upsample_nearest1d not implemented")
     }
 
     fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
@@ -970,7 +1007,7 @@ impl BackendStorage for MetalStorage {
         }
         let name = match self.dtype {
             DType::F32 => "upsample_nearest2d_f32",
-            dtype => crate::bail!("Not implemented {dtype:?} for upsample_nearest2d, metal"),
+            dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"),
         };
 
         let dst_el = out_w * out_h * dims[0] * dims[1];
@@ -1008,7 +1045,7 @@ impl BackendStorage for MetalStorage {
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "gather_u32_f32",
             (DType::U32, DType::F16) => "gather_u32_f16",
-            (left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"),
+            (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
         };
         let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_gather(
@@ -1081,7 +1118,7 @@ impl BackendStorage for MetalStorage {
             && ids_l.is_contiguous()
             && ids_l.start_offset() == 0)
         {
-            crate::bail!("Non contiguous index select not implemented");
+            crate::bail!("Metal strided index_select not implemented");
         }
         let left_size: usize = src_l.dims()[..dim].iter().product();
         let right_size: usize = src_l.dims()[dim + 1..].iter().product();
@@ -1093,7 +1130,9 @@ impl BackendStorage for MetalStorage {
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "is_u32_f32",
             (DType::U32, DType::F16) => "is_u32_f16",
-            (left, right) => crate::bail!("index select metal {left:?} {right:?}"),
+            (left, right) => {
+                crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
+            }
         };
         let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_index_select(
@@ -1134,7 +1173,7 @@ impl BackendStorage for MetalStorage {
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "ia_u32_f32",
             _ => Err(MetalError::UnexpectedDType {
-                msg: "index-add ids should be u8/u32/i64",
+                msg: "index-add ids should be u32",
                 expected: DType::U32,
                 got: ids.dtype(),
             })?,
@@ -1215,9 +1254,10 @@ impl BackendStorage for MetalStorage {
             DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
             DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
             DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+            DType::I64 => candle_metal_kernels::unary::strided::copy::I64,
             DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
             DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
-            dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
+            dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_unary_strided(
             &self.device.device,
@@ -1289,7 +1329,39 @@ impl MetalStorage {
             ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
             ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
             ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
-            (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
+            ("add", DType::I64) => (contiguous::add::I64, self.dtype),
+            ("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
+            ("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
+            ("div", DType::I64) => (contiguous::div::I64, self.dtype),
+            ("eq", DType::I64) => (contiguous::eq::I64, DType::U8),
+            ("ne", DType::I64) => (contiguous::ne::I64, DType::U8),
+            ("le", DType::I64) => (contiguous::le::I64, DType::U8),
+            ("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
+            ("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
+            ("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
+            ("add", DType::U32) => (contiguous::add::U32, self.dtype),
+            ("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
+            ("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
+            ("div", DType::U32) => (contiguous::div::U32, self.dtype),
+            ("eq", DType::U32) => (contiguous::eq::U32, DType::U8),
+            ("ne", DType::U32) => (contiguous::ne::U32, DType::U8),
+            ("le", DType::U32) => (contiguous::le::U32, DType::U8),
+            ("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
+            ("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
+            ("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
+            ("add", DType::U8) => (contiguous::add::U8, self.dtype),
+            ("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
+            ("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
+            ("div", DType::U8) => (contiguous::div::U8, self.dtype),
+            ("eq", DType::U8) => (contiguous::eq::U8, DType::U8),
+            ("ne", DType::U8) => (contiguous::ne::U8, DType::U8),
+            ("le", DType::U8) => (contiguous::le::U8, DType::U8),
+            ("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
+            ("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
+            ("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
+            (name, dtype) => {
+                crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
+            }
         };
         let buffer = device.new_buffer(el_count, dtype, op)?;
         candle_metal_kernels::call_binary_contiguous(
@@ -1332,7 +1404,45 @@ impl MetalStorage {
             ("lt", DType::F16) => (strided::lt::HALF, DType::U8),
             ("ge", DType::F16) => (strided::ge::HALF, DType::U8),
             ("gt", DType::F16) => (strided::gt::HALF, DType::U8),
-            (name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"),
+            ("badd", DType::I64) => (strided::add::I64, self.dtype),
+            ("bsub", DType::I64) => (strided::sub::I64, self.dtype),
+            ("bmul", DType::I64) => (strided::mul::I64, self.dtype),
+            ("bdiv", DType::I64) => (strided::div::I64, self.dtype),
+            ("bminimum", DType::I64) => (strided::min::I64, self.dtype),
+            ("bmaximum", DType::I64) => (strided::max::I64, self.dtype),
+            ("eq", DType::I64) => (strided::eq::I64, DType::U8),
+            ("ne", DType::I64) => (strided::ne::I64, DType::U8),
+            ("le", DType::I64) => (strided::le::I64, DType::U8),
+            ("lt", DType::I64) => (strided::lt::I64, DType::U8),
+            ("ge", DType::I64) => (strided::ge::I64, DType::U8),
+            ("gt", DType::I64) => (strided::gt::I64, DType::U8),
+            ("badd", DType::U32) => (strided::add::U32, self.dtype),
+            ("bsub", DType::U32) => (strided::sub::U32, self.dtype),
+            ("bmul", DType::U32) => (strided::mul::U32, self.dtype),
+            ("bdiv", DType::U32) => (strided::div::U32, self.dtype),
+            ("bminimum", DType::U32) => (strided::min::U32, self.dtype),
+            ("bmaximum", DType::U32) => (strided::max::U32, self.dtype),
+            ("eq", DType::U32) => (strided::eq::U32, DType::U8),
+            ("ne", DType::U32) => (strided::ne::U32, DType::U8),
+            ("le", DType::U32) => (strided::le::U32, DType::U8),
+            ("lt", DType::U32) => (strided::lt::U32, DType::U8),
+            ("ge", DType::U32) => (strided::ge::U32, DType::U8),
+            ("gt", DType::U32) => (strided::gt::U32, DType::U8),
+            ("badd", DType::U8) => (strided::add::U8, self.dtype),
+            ("bsub", DType::U8) => (strided::sub::U8, self.dtype),
+            ("bmul", DType::U8) => (strided::mul::U8, self.dtype),
+            ("bdiv", DType::U8) => (strided::div::U8, self.dtype),
+            ("bminimum", DType::U8) => (strided::min::U8, self.dtype),
+            ("bmaximum", DType::U8) => (strided::max::U8, self.dtype),
+            ("eq", DType::U8) => (strided::eq::U8, DType::U8),
+            ("ne", DType::U8) => (strided::ne::U8, DType::U8),
+            ("le", DType::U8) => (strided::le::U8, DType::U8),
+            ("lt", DType::U8) => (strided::lt::U8, DType::U8),
+            ("ge", DType::U8) => (strided::ge::U8, DType::U8),
+            ("gt", DType::U8) => (strided::gt::U8, DType::U8),
+            (name, dtype) => {
+                crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
+            }
         };
         let buffer = device.new_buffer(el_count, dtype, op)?;
         candle_metal_kernels::call_binary_strided(
@@ -1387,7 +1497,7 @@ impl BackendDevice for MetalDevice {
     }
 
     fn set_seed(&self, _seed: u64) -> Result<()> {
-        crate::bail!("set_seed")
+        crate::bail!("Metal set_seed not implemented")
     }
 
     fn location(&self) -> crate::DeviceLocation {

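Taken together, the extra kernel names above mean integer tensors stop falling through to the `bail!` arms on Metal. A rough sketch of what that enables from the public tensor API (assumes a Metal-capable machine; the snippet is illustrative and not part of this diff):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::new_metal(0)?;
    let a = Tensor::new(&[1u32, 2, 3], &device)?;
    let b = Tensor::new(&[3u32, 2, 1], &device)?;
    // U32 arithmetic now dispatches to the contiguous::add::U32 (etc.) kernels.
    let sum = (&a + &b)?;
    // Comparisons return a U8 mask, matching the (_, DType::U8) arms above.
    let mask = a.ge(&b)?;
    println!("{sum}\n{mask}");
    Ok(())
}
```
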
@@ -33,6 +33,8 @@ enum Which {
     V2,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
+    #[value(name = "tiny-llama-1.1b-chat")]
+    TinyLlama1_1BChat,
 }
 
 #[derive(Parser, Debug)]
@@ -124,6 +126,7 @@ fn main() -> Result<()> {
         Which::V1 => "Narsil/amall-7b".to_string(),
         Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
         Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
+        Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
     });
     println!("loading the model weights from {model_id}");
     let revision = args.revision.unwrap_or("main".to_string());
@@ -134,8 +137,12 @@ fn main() -> Result<()> {
     let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let config = config.into_config(args.use_flash_attn);
 
-    let filenames =
-        candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
+    let filenames = match args.which {
+        Which::V1 | Which::V2 | Which::Solar10_7B => {
+            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
+        }
+        Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
+    };
     println!("building the model");
     let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
 
@@ -158,14 +165,14 @@ fn main() -> Result<()> {
     let mut index_pos = 0;
     let mut token_generated = 0;
     for index in 0..args.sample_len {
-        let context_size = if cache.use_kv_cache && index > 0 {
-            1
+        let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
+            (1, index_pos)
         } else {
-            tokens.len()
+            (tokens.len(), 0)
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let logits = llama.forward(&input, index_pos)?;
+        let logits = llama.forward(&input, context_index)?;
         let logits = logits.squeeze(0)?;
         let logits = if args.repeat_penalty == 1. {
             logits

@@ -8,9 +8,16 @@ Python package with:
 pip install "gymnasium[accept-rom-license]"
 ```
 
-In order to run the example, use the following command. Note the additional
+In order to run the examples, use the following commands. Note the additional
 `--package` flag to ensure that there is no conflict with the `candle-pyo3`
 crate.
 
+For the Policy Gradient example:
 ```bash
-cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
+cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- pg
+```
+
+For the Deep Deterministic Policy Gradient example:
+```bash
+cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- ddpg
 ```

@@ -8,6 +8,8 @@ use candle_nn::{
 };
 use rand::{distributions::Uniform, thread_rng, Rng};
 
+use super::gym_env::GymEnv;
+
 pub struct OuNoise {
     mu: f64,
     theta: f64,
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The impact of the q value of the next state on the current state's q value.
|
||||||
|
const GAMMA: f64 = 0.99;
|
||||||
|
// The weight for updating the target networks.
|
||||||
|
const TAU: f64 = 0.005;
|
||||||
|
// The capacity of the replay buffer used for sampling training data.
|
||||||
|
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
|
||||||
|
// The training batch size for each training iteration.
|
||||||
|
const TRAINING_BATCH_SIZE: usize = 100;
|
||||||
|
// The total number of episodes.
|
||||||
|
const MAX_EPISODES: usize = 100;
|
||||||
|
// The maximum length of an episode.
|
||||||
|
const EPISODE_LENGTH: usize = 200;
|
||||||
|
// The number of training iterations after one episode finishes.
|
||||||
|
const TRAINING_ITERATIONS: usize = 200;
|
||||||
|
|
||||||
|
// Ornstein-Uhlenbeck process parameters.
|
||||||
|
const MU: f64 = 0.0;
|
||||||
|
const THETA: f64 = 0.15;
|
||||||
|
const SIGMA: f64 = 0.1;
|
||||||
|
|
||||||
|
const ACTOR_LEARNING_RATE: f64 = 1e-4;
|
||||||
|
const CRITIC_LEARNING_RATE: f64 = 1e-3;
|
||||||
|
|
||||||
|
pub fn run() -> Result<()> {
|
||||||
|
let env = GymEnv::new("Pendulum-v1")?;
|
||||||
|
println!("action space: {}", env.action_space());
|
||||||
|
println!("observation space: {:?}", env.observation_space());
|
||||||
|
|
||||||
|
let size_state = env.observation_space().iter().product::<usize>();
|
||||||
|
let size_action = env.action_space();
|
||||||
|
|
||||||
|
let mut agent = DDPG::new(
|
||||||
|
&Device::Cpu,
|
||||||
|
size_state,
|
||||||
|
size_action,
|
||||||
|
true,
|
||||||
|
ACTOR_LEARNING_RATE,
|
||||||
|
CRITIC_LEARNING_RATE,
|
||||||
|
GAMMA,
|
||||||
|
TAU,
|
||||||
|
REPLAY_BUFFER_CAPACITY,
|
||||||
|
OuNoise::new(MU, THETA, SIGMA, size_action)?,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
|
||||||
|
for episode in 0..MAX_EPISODES {
|
||||||
|
// let mut state = env.reset(episode as u64)?;
|
||||||
|
let mut state = env.reset(rng.gen::<u64>())?;
|
||||||
|
|
||||||
|
let mut total_reward = 0.0;
|
||||||
|
for _ in 0..EPISODE_LENGTH {
|
||||||
|
let mut action = 2.0 * agent.actions(&state)?;
|
||||||
|
action = action.clamp(-2.0, 2.0);
|
||||||
|
|
||||||
|
let step = env.step(vec![action])?;
|
||||||
|
total_reward += step.reward;
|
||||||
|
|
||||||
|
agent.remember(
|
||||||
|
&state,
|
||||||
|
&Tensor::new(vec![action], &Device::Cpu)?,
|
||||||
|
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
|
||||||
|
&step.state,
|
||||||
|
step.terminated,
|
||||||
|
step.truncated,
|
||||||
|
);
|
||||||
|
|
||||||
|
if step.terminated || step.truncated {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
state = step.state;
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("episode {episode} with total reward of {total_reward}");
|
||||||
|
|
||||||
|
for _ in 0..TRAINING_ITERATIONS {
|
||||||
|
agent.train(TRAINING_BATCH_SIZE)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("Testing...");
|
||||||
|
agent.train = false;
|
||||||
|
for episode in 0..10 {
|
||||||
|
// let mut state = env.reset(episode as u64)?;
|
||||||
|
let mut state = env.reset(rng.gen::<u64>())?;
|
||||||
|
let mut total_reward = 0.0;
|
||||||
|
for _ in 0..EPISODE_LENGTH {
|
||||||
|
let mut action = 2.0 * agent.actions(&state)?;
|
||||||
|
action = action.clamp(-2.0, 2.0);
|
||||||
|
|
||||||
|
let step = env.step(vec![action])?;
|
||||||
|
total_reward += step.reward;
|
||||||
|
|
||||||
|
if step.terminated || step.truncated {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
state = step.state;
|
||||||
|
}
|
||||||
|
println!("episode {episode} with total reward of {total_reward}");
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -6,139 +6,32 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use candle::Result;
|
||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
|
||||||
mod gym_env;
|
mod gym_env;
|
||||||
mod vec_gym_env;
|
mod vec_gym_env;
|
||||||
|
|
||||||
mod ddpg;
|
mod ddpg;
|
||||||
|
mod policy_gradient;
|
||||||
|
|
||||||
use candle::{Device, Result, Tensor};
|
#[derive(Parser)]
|
||||||
use clap::Parser;
|
|
||||||
use rand::Rng;
|
|
||||||
|
|
||||||
// The impact of the q value of the next state on the current state's q value.
|
|
||||||
const GAMMA: f64 = 0.99;
|
|
||||||
// The weight for updating the target networks.
|
|
||||||
const TAU: f64 = 0.005;
|
|
||||||
// The capacity of the replay buffer used for sampling training data.
|
|
||||||
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
|
|
||||||
// The training batch size for each training iteration.
|
|
||||||
const TRAINING_BATCH_SIZE: usize = 100;
|
|
||||||
// The total number of episodes.
|
|
||||||
const MAX_EPISODES: usize = 100;
|
|
||||||
// The maximum length of an episode.
|
|
||||||
const EPISODE_LENGTH: usize = 200;
|
|
||||||
// The number of training iterations after one episode finishes.
|
|
||||||
const TRAINING_ITERATIONS: usize = 200;
|
|
||||||
|
|
||||||
// Ornstein-Uhlenbeck process parameters.
|
|
||||||
const MU: f64 = 0.0;
|
|
||||||
const THETA: f64 = 0.15;
|
|
||||||
const SIGMA: f64 = 0.1;
|
|
||||||
|
|
||||||
const ACTOR_LEARNING_RATE: f64 = 1e-4;
|
|
||||||
const CRITIC_LEARNING_RATE: f64 = 1e-3;
|
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
struct Args {
|
||||||
/// Run on CPU rather than on GPU.
|
#[command(subcommand)]
|
||||||
#[arg(long)]
|
command: Command,
|
||||||
cpu: bool,
|
}
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
#[derive(Subcommand)]
|
||||||
#[arg(long)]
|
enum Command {
|
||||||
tracing: bool,
|
Pg,
|
||||||
|
Ddpg,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
match args.command {
|
||||||
let _guard = if args.tracing {
|
Command::Pg => policy_gradient::run()?,
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
Command::Ddpg => ddpg::run()?,
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let env = gym_env::GymEnv::new("Pendulum-v1")?;
|
|
||||||
println!("action space: {}", env.action_space());
|
|
||||||
println!("observation space: {:?}", env.observation_space());
|
|
||||||
|
|
||||||
let size_state = env.observation_space().iter().product::<usize>();
|
|
||||||
let size_action = env.action_space();
|
|
||||||
|
|
||||||
let mut agent = ddpg::DDPG::new(
|
|
||||||
&Device::Cpu,
|
|
||||||
size_state,
|
|
||||||
size_action,
|
|
||||||
true,
|
|
||||||
ACTOR_LEARNING_RATE,
|
|
||||||
CRITIC_LEARNING_RATE,
|
|
||||||
GAMMA,
|
|
||||||
TAU,
|
|
||||||
REPLAY_BUFFER_CAPACITY,
|
|
||||||
ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let mut rng = rand::thread_rng();
|
|
||||||
|
|
||||||
for episode in 0..MAX_EPISODES {
|
|
||||||
// let mut state = env.reset(episode as u64)?;
|
|
||||||
let mut state = env.reset(rng.gen::<u64>())?;
|
|
||||||
|
|
||||||
let mut total_reward = 0.0;
|
|
||||||
for _ in 0..EPISODE_LENGTH {
|
|
||||||
let mut action = 2.0 * agent.actions(&state)?;
|
|
||||||
action = action.clamp(-2.0, 2.0);
|
|
||||||
|
|
||||||
let step = env.step(vec![action])?;
|
|
||||||
total_reward += step.reward;
|
|
||||||
|
|
||||||
agent.remember(
|
|
||||||
&state,
|
|
||||||
&Tensor::new(vec![action], &Device::Cpu)?,
|
|
||||||
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
|
|
||||||
&step.state,
|
|
||||||
step.terminated,
|
|
||||||
step.truncated,
|
|
||||||
);
|
|
||||||
|
|
||||||
if step.terminated || step.truncated {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
state = step.state;
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("episode {episode} with total reward of {total_reward}");
|
|
||||||
|
|
||||||
for _ in 0..TRAINING_ITERATIONS {
|
|
||||||
agent.train(TRAINING_BATCH_SIZE)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("Testing...");
|
|
||||||
agent.train = false;
|
|
||||||
for episode in 0..10 {
|
|
||||||
// let mut state = env.reset(episode as u64)?;
|
|
||||||
let mut state = env.reset(rng.gen::<u64>())?;
|
|
||||||
let mut total_reward = 0.0;
|
|
||||||
for _ in 0..EPISODE_LENGTH {
|
|
||||||
let mut action = 2.0 * agent.actions(&state)?;
|
|
||||||
action = action.clamp(-2.0, 2.0);
|
|
||||||
|
|
||||||
let step = env.step(vec![action])?;
|
|
||||||
total_reward += step.reward;
|
|
||||||
|
|
||||||
if step.terminated || step.truncated {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
state = step.state;
|
|
||||||
}
|
|
||||||
println!("episode {episode} with total reward of {total_reward}");
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@@ -0,0 +1,146 @@
+use super::gym_env::{GymEnv, Step};
+use candle::{DType, Device, Error, Module, Result, Tensor};
+use candle_nn::{
+    linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
+    ParamsAdamW, VarBuilder, VarMap,
+};
+use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
+
+fn new_model(
+    input_shape: &[usize],
+    num_actions: usize,
+    dtype: DType,
+    device: &Device,
+) -> Result<(impl Module, VarMap)> {
+    let input_size = input_shape.iter().product();
+
+    let mut varmap = VarMap::new();
+    let var_builder = VarBuilder::from_varmap(&varmap, dtype, device);
+
+    let model = seq()
+        .add(linear(input_size, 32, var_builder.pp("lin1"))?)
+        .add(Activation::Relu)
+        .add(linear(32, num_actions, var_builder.pp("lin2"))?);
+
+    Ok((model, varmap))
+}
+
+fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
+    let mut rewards: Vec<f64> = steps.iter().map(|s| s.reward).collect();
+    let mut acc_reward = 0f64;
+    for (i, reward) in rewards.iter_mut().enumerate().rev() {
+        if steps[i].terminated {
+            acc_reward = 0.0;
+        }
+        acc_reward += *reward;
+        *reward = acc_reward;
+    }
+    rewards
+}
+
+fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
+    let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
+    let mut rng = rng;
+    Ok(distribution.sample(&mut rng))
+}
+
+pub fn run() -> Result<()> {
+    let env = GymEnv::new("CartPole-v1")?;
+
+    println!("action space: {:?}", env.action_space());
+    println!("observation space: {:?}", env.observation_space());
+
+    let (model, varmap) = new_model(
+        env.observation_space(),
+        env.action_space(),
+        DType::F32,
+        &Device::Cpu,
+    )?;
+
+    let optimizer_params = ParamsAdamW {
+        lr: 0.01,
+        weight_decay: 0.01,
+        ..Default::default()
+    };
+
+    let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;
+
+    let mut rng = rand::thread_rng();
+
+    for epoch_idx in 0..100 {
+        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut steps: Vec<Step<i64>> = vec![];
+
+        loop {
+            let action = {
+                let action_probs: Vec<f32> =
+                    softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
+                        .squeeze(0)?
+                        .to_vec1()?;
+                weighted_sample(action_probs, &mut rng)? as i64
+            };
+
+            let step = env.step(action)?;
+            steps.push(step.copy_with_obs(&state));
+
+            if step.terminated || step.truncated {
+                state = env.reset(rng.gen::<u64>())?;
+                if steps.len() > 5000 {
+                    break;
+                }
+            } else {
+                state = step.state;
+            }
+        }
+
+        let total_reward: f64 = steps.iter().map(|s| s.reward).sum();
+        let episodes: i64 = steps
+            .iter()
+            .map(|s| (s.terminated || s.truncated) as i64)
+            .sum();
+        println!(
+            "epoch: {:<3} episodes: {:<5} avg reward per episode: {:.2}",
+            epoch_idx,
+            episodes,
+            total_reward / episodes as f64
+        );
+
+        let batch_size = steps.len();
+
+        let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
+            .to_dtype(DType::F32)?
+            .detach()?;
+
+        let actions_mask = {
+            let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
+            let actions_mask: Vec<Tensor> = actions
+                .iter()
+                .map(|&action| {
+                    // One-hot encoding
+                    let mut action_mask = vec![0.0; env.action_space()];
+                    action_mask[action as usize] = 1.0;
+
+                    Tensor::from_vec(action_mask, env.action_space(), &Device::Cpu)
+                        .unwrap()
+                        .to_dtype(DType::F32)
+                        .unwrap()
+                })
+                .collect();
+            Tensor::stack(&actions_mask, 0)?.detach()?
+        };
+
+        let states = {
+            let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
+            Tensor::stack(&states, 0)?.detach()?
+        };
+
+        let log_probs = actions_mask
+            .mul(&log_softmax(&model.forward(&states)?, 1)?)?
+            .sum(1)?;
+
+        let loss = rewards.mul(&log_probs)?.neg()?.mean_all()?;
+        optimizer.backward_step(&loss)?;
+    }
+
+    Ok(())
+}

@@ -29,7 +29,7 @@ e.g.:
 
 ```bash
 cargo run --example stable-diffusion --release --features=cuda,cudnn \
-  -- --prompt "a cosmonaut on a horse (hd, realistic, high-def) --sd-version turbo"
+  -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" --sd-version turbo
 ```
 
 The default scheduler for the v1.5, v2.1 and XL 1.0 version is the Denoising

@@ -147,7 +147,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
     let func = candle_nn::func(move |xs| {
         let xs = conv.forward(xs)?;
         let xs = match &bn {
-            Some(bn) => bn.forward(&xs)?,
+            Some(bn) => xs.apply_t(bn, false)?,
             None => xs,
         };
         let xs = if leaky {

62
candle-flash-attn/kernels/alibi.h
Normal file
62
candle-flash-attn/kernels/alibi.h
Normal file
@@ -0,0 +1,62 @@
+#include <cmath>
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+
+#include "utils.h"
+
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_causal, typename Engine, typename Layout>
+inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
+                                   const int col_idx_offset_,
+                                   const int max_seqlen_k,
+                                   const int row_idx_offset,
+                                   const int max_seqlen_q,
+                                   const int warp_row_stride,
+                                   const float alibi_slope) {
+    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    static_assert(Layout::rank == 2, "Only support 2D Tensor");
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
+        #pragma unroll
+        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+            const int col_idx_base = col_idx_offset + nj * 8;
+            #pragma unroll
+            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                const int col_idx = col_idx_base + j;
+                #pragma unroll
+                for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                }
+            }
+        }
+    } else {  // Bias depends on both row_idx and col_idx
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            #pragma unroll
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                    }
+                }
+            }
+        }
+    }
+}
+
+}  // namespace flash
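In other words, `apply_alibi` adds a linear positional bias to the attention scores before softmax: in the causal case every row gets `slope * col`, otherwise each element gets `-slope * |row + seqlen_k - seqlen_q - col|`. A small host-side Rust restatement of the same formula (a hypothetical helper, for illustration only):

```rust
/// Reference computation of the ALiBi bias added to a (seqlen_q x seqlen_k)
/// score matrix, mirroring the kernel above on the CPU.
fn alibi_bias(seqlen_q: usize, seqlen_k: usize, slope: f32, causal: bool) -> Vec<Vec<f32>> {
    (0..seqlen_q)
        .map(|row| {
            (0..seqlen_k)
                .map(|col| {
                    if causal {
                        // Same bias vector for every row.
                        slope * col as f32
                    } else {
                        // Distance-based bias, aligned to the bottom-right corner.
                        let dist = (row as i64 + seqlen_k as i64 - seqlen_q as i64 - col as i64).abs();
                        -slope * dist as f32
                    }
                })
                .collect()
        })
        .collect()
}

fn main() {
    for row in alibi_bias(2, 3, 0.5, false) {
        println!("{row:?}");
    }
}
```

For the causal case, adding `slope * col` differs from the distance form only by a per-row constant, which softmax ignores.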
@@ -14,9 +14,12 @@ struct BlockInfo {
     template<typename Params>
     __device__ BlockInfo(const Params &params, const int bidb)
         : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
-        , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
         {
         }

@@ -32,8 +35,10 @@ struct BlockInfo {

     const int sum_s_q;
     const int sum_s_k;
-    const uint32_t actual_seqlen_q;
-    const uint32_t actual_seqlen_k;
+    const int actual_seqlen_q;
+    // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int seqlen_k_cache;
+    const int actual_seqlen_k;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
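A compact way to read the new `seqlen_k_cache` / `actual_seqlen_k` logic: `cu_seqlens_k` is interpreted either as cumulative offsets or as per-batch lengths, `seqused_k` overrides everything when present, and appended K/V (`seqlen_knew`) extends the cached length. A hedged Rust sketch of the same decision tree (field names mirror the struct above; the helper itself is hypothetical):

```rust
/// Host-side restatement of the BlockInfo sequence-length selection above.
fn k_seqlens(
    cu_seqlens_k: Option<&[i32]>, // cumulative offsets or per-batch lengths
    is_cumulative: bool,
    seqused_k: Option<&[i32]>,    // explicit per-batch override, if provided
    seqlen_k: i32,                // fixed length when cu_seqlens_k is absent
    seqlen_knew: i32,             // length of appended K/V, 0 if none
    bidb: usize,                  // batch index
) -> (i32, i32) {
    let seqlen_k_cache = match cu_seqlens_k {
        None => seqlen_k,
        Some(cu) if is_cumulative => cu[bidb + 1] - cu[bidb],
        Some(cu) => cu[bidb],
    };
    let actual_seqlen_k = match seqused_k {
        Some(used) => used[bidb],
        None => seqlen_k_cache + seqlen_knew,
    };
    (seqlen_k_cache, actual_seqlen_k)
}

fn main() {
    // Cumulative offsets [0, 5, 9]: batch 1 has length 4, plus 2 freshly appended tokens.
    println!("{:?}", k_seqlens(Some(&[0, 5, 9]), true, None, 0, 2, 1));
}
```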
@@ -7,15 +7,6 @@
 #include <cuda.h>
 #include <vector>

-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh>
-
-
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
 constexpr int D_DIM = 2;
@@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params {

     // The O matrix (output).
     void * __restrict__ o_ptr;
+    void * __restrict__ oaccum_ptr;

     // The stride between rows of O.
     index_t o_batch_stride;

@@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params {

     // The pointer to the softmax sum.
     void * __restrict__ softmax_lse_ptr;
+    void * __restrict__ softmax_lseaccum_ptr;

     // The dimensions.
-    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;

     // The scaling factors for the kernel.
     float scale_softmax;

@@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;

+    // If provided, the actual length of each k sequence.
+    int * __restrict__ seqused_k;
+
     int *__restrict__ blockmask;

+    // The K_new and V_new matrices.
+    void * __restrict__ knew_ptr;
+    void * __restrict__ vnew_ptr;
+
+    // The stride between rows of the Q, K and V matrices.
+    index_t knew_batch_stride;
+    index_t vnew_batch_stride;
+    index_t knew_row_stride;
+    index_t vnew_row_stride;
+    index_t knew_head_stride;
+    index_t vnew_head_stride;
+
+    // The cos and sin matrices for rotary embedding.
+    void * __restrict__ rotary_cos_ptr;
+    void * __restrict__ rotary_sin_ptr;
+
+    // The indices to index into the KV cache.
+    int *__restrict__ cache_batch_idx;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;

@@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params {
     float rp_dropout;
     float scale_softmax_rp_dropout;

-    // Random state.
-    // at::PhiloxCudaState philox_args;
+    // Local window size
+    int window_size_left, window_size_right;

     bool is_bf16;
     bool is_causal;
+
+    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+    bool is_seqlens_k_cumulative;
+
+    bool is_rotary_interleaved;
+
+    int num_splits;  // For split-KV version
+
+    void * __restrict__ alibi_slopes_ptr;
+    index_t alibi_slopes_batch_stride;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
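`alibi_slopes_ptr` above carries one slope per attention head. This diff does not show how those slopes are produced; the commonly used schedule from the ALiBi paper (stated here as background, not as this commit's code) is a geometric sequence, which a caller could build like this hypothetical helper:

```rust
/// Conventional ALiBi slope schedule for `num_heads` heads (a power of two is
/// assumed for simplicity); head h gets slope 2^(-8 * (h + 1) / num_heads).
fn alibi_slopes(num_heads: usize) -> Vec<f32> {
    (1..=num_heads)
        .map(|h| 2f32.powf(-8.0 * h as f32 / num_heads as f32))
        .collect()
}

fn main() {
    // For 8 heads: 2^-1, 2^-2, ..., 2^-8.
    println!("{:?}", alibi_slopes(8));
}
```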
@@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params {

     // The pointer to the softmax d sum.
     void *__restrict__ dsoftmax_sum;
+
+    bool deterministic;
+    index_t dq_accum_split_stride;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

 template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
@@ -1,15 +1,13 @@
 #include "flash_fwd_launch_template.h"

-// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-//     FWD_HEADDIM_SWITCH(params.d, [&] {
-//         run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
-//     });
-// }
-
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         FWD_HEADDIM_SWITCH(params.d, [&] {
+            // if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
             run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+            // } else {
+            //     run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
+            // }
         });
     });
 }
@@ -20,6 +18,7 @@ extern "C" void run_mha(
     void *v_ptr,
     void *o_ptr,
     void *softmax_lse_ptr,
+    void *alibi_slopes_ptr,

     int32_t *cu_seqlens_q_ptr,
     int32_t *cu_seqlens_k_ptr,

@@ -28,6 +27,7 @@ extern "C" void run_mha(
     uint32_t k_batch_stride,
     uint32_t v_batch_stride,
     uint32_t o_batch_stride,
+    uint32_t alibi_slopes_batch_stride,

     uint32_t q_row_stride,
     uint32_t k_row_stride,
@@ -51,8 +51,11 @@ extern "C" void run_mha(
     uint32_t seqlen_q_rounded,
     uint32_t seqlen_k_rounded,

+    int is_bf16,
     int is_causal,
-    int is_bf16
+
+    int window_size_left,
+    int window_size_right
 ) {
     Flash_fwd_params params;
     // Reset the parameters

@@ -65,12 +68,14 @@ extern "C" void run_mha(
     params.o_ptr = o_ptr;

     params.softmax_lse_ptr = softmax_lse_ptr;
+    params.alibi_slopes_ptr = alibi_slopes_ptr;

     // All stride are in elements, not bytes.
     params.q_batch_stride = q_batch_stride;
     params.k_batch_stride = k_batch_stride;
     params.v_batch_stride = v_batch_stride;
     params.o_batch_stride = o_batch_stride;
+    params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;

     params.q_row_stride = q_row_stride;
     params.k_row_stride = k_row_stride;

@@ -92,7 +97,6 @@ extern "C" void run_mha(
     params.seqlen_k_rounded = seqlen_k_rounded;
     params.d = d;
     params.d_rounded = d_rounded;
-    params.is_causal = is_causal;

     // Set the different scale values.
     params.scale_softmax = softmax_scale;

@@ -106,6 +110,14 @@ extern "C" void run_mha(
     params.cu_seqlens_q = cu_seqlens_q_ptr;
     params.cu_seqlens_k = cu_seqlens_k_ptr;
     params.p_ptr = nullptr; // used for `return_softmax`.
+    params.seqused_k = nullptr;
+
+    params.is_causal = is_causal;
+    params.window_size_left = window_size_left;
+    params.window_size_right = window_size_right;
+
+    params.is_seqlens_k_cumulative = true;
+    params.num_splits = 1;

     cudaStream_t stream = 0; // Use the default stream.
     run_mha_fwd(params, stream);
@@ -1,18 +1,9 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::bfloat16_t;
-//     if (params.p_dropout == 1.f) {
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
-//     } else {
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
-//     }
-// }
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
@@ -1,31 +1,9 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     if (params.p_dropout == 1.f) {
-//         // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
-//         // 1st ones are good for H100, A100
-//         // 2nd one is good for A6000 bc we get slightly better occupancy
-//     } else {
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
-//         // 1st one is good for H100, A100, A6000
-//     }
-// }
-
 template<>
 void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
@@ -1,16 +1,9 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::bfloat16_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//     });
-// }
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
@@ -1,26 +1,9 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
-//         // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
-//         // For A100, H100, 1st is fastest.
-//     });
-// }
 template<>
 void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
@@ -1,16 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::bfloat16_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//     });
-// }
-template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
 }
@@ -1,26 +1,9 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         // This one is slightly faster for causal?
-//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
-//     });
-//     // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
-//     // For A6000, 1st is faster when causal, 3rd is faster when not causal
-// }
 template<>
 void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
@@ -1,9 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
 }
@@ -1,9 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
 }
@@ -1,9 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
 }
@@ -1,9 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
 }
@@ -1,6 +1,6 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

@@ -1,22 +1,9 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         // For dropout there might be a lot of register spilling?
-//         // These two are very slow due to register spilling
-//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
-//         // This one is slightly slower
-//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
-//     });
-// }
 template<>
 void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
@@ -1,18 +1,9 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::bfloat16_t;
-//     if (params.p_dropout == 1.f) {
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
-//     } else {
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
-//     }
-// }
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
@@ -1,25 +1,9 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     if (params.p_dropout == 1.f) {
-//         // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
-//         // Using block size (64 x 256) is 27% slower for seqlen=2k
-//         // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
-//     } else {
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
-//     }
-// }
 template<>
 void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
@@ -1,16 +1,9 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::bfloat16_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
-//     });
-// }
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
@@ -1,23 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
-
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"

 #include "flash_fwd_launch_template.h"

-// template<>
-// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
-//         // This 3rd one is good for H100, and A100, A6000
-//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
-//         // These two are always slower
-//         // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
-//     });
-// }
-template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
 }
@@ -4,20 +4,18 @@

 #pragma once

-#include <cmath>
 #include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>

 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
 #include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>

 #include "block_info.h"
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
-#include "philox.cuh"
+
+#include "alibi.h"

 namespace flash {
@@ -25,49 +23,6 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <int MMA_M,
-          class... Args,
-          class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
-                                  TiledMMA const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
-    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
-    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
-    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
-    constexpr int MMAStride_M = MMA_M * AtomShape_M;
-    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
-                              Stride<_1, Int<MMAStride_M>> >{},
-                       make_layout(size<2>(TileShape_MNK{})));
-    // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
-    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <int MMA_M,
-          class... Args,
-          class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
-                                  TiledMMA const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
-    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
-    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
-    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
-    constexpr int MMAStride_M = MMA_M * AtomShape_M;
-    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
-                              Stride<_1, Int<MMAStride_M>> >{},
-                       // TODO: Shouldn't this be size<1>?
-                       make_layout(size<2>(TileShape_MNK{})));
-    // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
-    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
 inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
                                          Tensor2 &acc_o, float softmax_scale_log2) {
@@ -77,7 +32,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
         flash::reduce_sum(scores, scores_sum);
     } else {
         Tensor scores_max_prev = make_fragment_like(scores_max);
-        copy(scores_max, scores_max_prev);
+        cute::copy(scores_max, scores_max_prev);
         flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
         // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
         Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
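For readers less familiar with `softmax_rescale_o`: it performs the online-softmax rescaling used by FlashAttention, where a running max and a running sum are updated block by block and previously accumulated state is rescaled by `exp(old_max - new_max)`. A scalar Rust sketch of the same idea (illustrative only, not the kernel's data layout):

```rust
/// Online softmax over a stream of score blocks: keep a running max and a
/// running sum of exp(score - max), rescaling the sum whenever the max grows.
fn online_softmax_denominator(blocks: &[Vec<f32>]) -> (f32, f32) {
    let mut running_max = f32::NEG_INFINITY;
    let mut running_sum = 0.0f32;
    for block in blocks {
        let block_max = block.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let new_max = running_max.max(block_max);
        // Rescale what we accumulated so far to the new reference max
        // (the kernel rescales the partial output acc_o the same way).
        running_sum *= (running_max - new_max).exp();
        running_sum += block.iter().map(|&s| (s - new_max).exp()).sum::<f32>();
        running_max = new_max;
    }
    (running_max, running_sum)
}

fn main() {
    let (m, s) = online_softmax_denominator(&[vec![1.0, 2.0], vec![3.0]]);
    // Equals the max and the sum of exp(x - max) over the flattened scores.
    println!("max = {m}, denom = {s}");
}
```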
@@ -103,23 +58,22 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T

 template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
 inline __device__ void write_softmax_to_gmem(
-    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
+    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
 ) {
     // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
     Layout l = tOrP.layout();
     Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
     CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
-    // TODO(laurent): reactivate the following
-    // CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
+    CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
     #pragma unroll
     for (int mi = 0; mi < size<1>(tPrP); ++mi) {
-        copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
+        cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
     }
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

     using Element = typename Kernel_traits::Element;
@@ -138,16 +92,65 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kNWarps = Kernel_traits::kNWarps;
     constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;

-    const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
-    if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
+    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

+    const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
-    if (Is_causal) {
-        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
+    if (Is_causal || Is_local) {
+        n_block_max = std::min(n_block_max,
+                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
         // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
         //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
         // }
     }
+    // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
+    // Otherwise we might read OOB elements from gK and gV.
+    if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
+        // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
+        // exit early and no one saves the rng state.
+        // if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+        //     auto seeds = at::cuda::philox::unpack(params.philox_args);
+        //     params.rng_state[0] = std::get<0>(seeds);
+        //     params.rng_state[1] = std::get<1>(seeds);
+        //     params.rng_state[0] = 0;
+        //     params.rng_state[1] = 0;
+        // }
+        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+        const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+        Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                make_stride(params.o_row_stride, _1{}));
+        Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+                                  Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+        typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+        Tensor tOrO = make_tensor<Element>(shape(tOgO));
+        clear(tOrO);
+        // Construct identity layout for sO
+        Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+        // Repeat the partitioning with identity layouts
+        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+        if (!Is_even_K) {
+            #pragma unroll
+            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+        }
+        // Clear_OOB_K must be false since we don't want to write zeros to gmem
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+        );
+        #pragma unroll
+        for (int m = 0; m < size<1>(tOgO); ++m) {
+            const int row = get<0>(tOcO(0, m, 0));
+            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
+        }
+        return;
+    }
+    // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }

     // We iterate over the blocks in reverse order. This is because the last block is the only one
     // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
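The `n_block_min`/`n_block_max` computation above decides which K/V column blocks a given query block has to visit once causal or sliding-window ("local") masking is taken into account; if the range is empty, the kernel now exits early after zero-filling its output. A small Rust restatement of the block-range arithmetic, for illustration only (`ceil_div` is a hypothetical helper):

```rust
fn ceil_div(a: i64, b: i64) -> i64 {
    (a + b - 1) / b
}

/// Range of K/V blocks [n_block_min, n_block_max) that a query block must scan,
/// mirroring the causal/local logic in compute_attn_1rowblock.
fn kv_block_range(
    m_block: i64,
    block_m: i64,
    block_n: i64,
    seqlen_q: i64,
    seqlen_k: i64,
    is_causal: bool,
    is_local: bool,
    window_left: i64,
    window_right: i64,
) -> (i64, i64) {
    let n_block_min = if is_local {
        std::cmp::max(0, (m_block * block_m + seqlen_k - seqlen_q - window_left) / block_n)
    } else {
        0
    };
    let mut n_block_max = ceil_div(seqlen_k, block_n);
    if is_causal || is_local {
        n_block_max = n_block_max.min(ceil_div(
            (m_block + 1) * block_m + seqlen_k - seqlen_q + window_right,
            block_n,
        ));
    }
    (n_block_min, n_block_max)
}

fn main() {
    // Causal attention, 128-wide blocks, equal q/k lengths: the first query block
    // only needs to scan the first K/V block.
    println!("{:?}", kv_block_range(0, 128, 128, 1024, 1024, true, false, 0, 0));
}
```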
@@ -185,8 +188,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

-    auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
-    auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
+    auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);

     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -208,16 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Copy Atom retiling
     //

-    auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
-    // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
     // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
     Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
     // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}

-    auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_K.partition_S(sK);

-    auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

     // TODO: this might need to change if we change the mma instruction in SM70
@@ -268,7 +275,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

     Tensor tQrQ = make_fragment_like(tQgQ);
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                        binfo.actual_seqlen_q - m_block * kBlockM);
     if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

@@ -285,13 +292,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
-        copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
         __syncthreads();
     }

     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                                       binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
@@ -302,7 +309,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
-        copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
     }

     // auto seeds = at::cuda::philox::unpack(params.philox_args);
@@ -313,13 +320,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

    clear(acc_o);

+   float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+
    // For performance reason, we separate out two kinds of iterations:
    // those that need masking on S, and those that don't.
    // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
    // We will have at least 1 "masking" iteration.

-   constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
+   // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+   // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+   constexpr int n_masking_steps = (!Is_causal && !Is_local)
+       ? 1
+       : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
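A quick sanity check of the new n_masking_steps selection above. This is a minimal stand-alone sketch (not part of the diff) with illustrative block sizes, just to show which branch fires for each flag combination:

// Illustration only: mirrors the n_masking_steps expression from the hunk above.
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

constexpr int n_masking_steps(bool is_causal, bool is_local, bool is_even_mn,
                              int kBlockM, int kBlockN) {
    return (!is_causal && !is_local)
        ? 1
        : ((is_even_mn && is_causal) ? ceil_div(kBlockM, kBlockN)
                                     : ceil_div(kBlockM, kBlockN) + 1);
}

int main() {
    // With kBlockM == kBlockN == 128: non-causal/non-local needs 1 masking step,
    // causal with even sequence lengths needs 1, causal with ragged lengths needs 2.
    std::printf("%d %d %d\n",
                n_masking_steps(false, false, true,  128, 128),   // 1
                n_masking_steps(true,  false, true,  128, 128),   // 1
                n_masking_steps(true,  false, false, 128, 128));  // 2
    return 0;
}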
@@ -330,28 +343,42 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
            // Advance gV
            if (masking_step > 0) {
                tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-               flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+               flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
            } else {
                // Clear the smem tiles to account for predicated off loads
-               flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
-                   gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+               flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+                   gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
                );
            }
            cute::cp_async_fence();

            flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
-               acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+               acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+               smem_thr_copy_Q, smem_thr_copy_K
            );
            // if (cute::thread0()) { print(acc_s); }

            // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
            Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-           // if (cute::thread0()) { print(scores); }
+           // if (cute::thread0()) { print_tensor(scores); }
            // We don't put the masking before the matmul S = Q K^T because we don't clear sK
            // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
            // can produce Inf / NaN.
-           if (!Is_causal) {
-               if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
+
+           if (Has_alibi) {
+               flash::apply_alibi<Is_causal>(
+                   scores,
+                   n_block * kBlockN,
+                   binfo.actual_seqlen_k,
+                   m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                   binfo.actual_seqlen_q,
+                   kNWarps * 16,
+                   alibi_slope
+               );
+           }
+
+           if (!Is_causal && !Is_local) {
+               if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
            } else {
                // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
                // Tensor taccScS = thr_mma.partition_C(caccS);                           // (MMA,MMA_M,MMA_N)
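For context, the Has_alibi branch above adds a per-head linear position bias to the raw scores before the softmax, using the alibi_slope computed earlier (the softmax scale is folded into the slope, which is why it is divided by params.scale_softmax). A rough host-side sketch of the same idea, assuming the last query is aligned with the last key as the seqlen_k - seqlen_q shift in the kernel suggests; this is an illustration, not the kernel's fragment-level implementation:

// Simplified ALiBi-style bias: score(i, j) -= slope * |relative distance| (illustrative only).
#include <cmath>
#include <vector>

void apply_alibi_reference(std::vector<float> &scores, int seqlen_q, int seqlen_k, float slope) {
    for (int i = 0; i < seqlen_q; ++i) {
        for (int j = 0; j < seqlen_k; ++j) {
            // Align query row i with key column i + (seqlen_k - seqlen_q).
            const int rel = j - (i + seqlen_k - seqlen_q);
            scores[i * seqlen_k + j] -= slope * std::abs(static_cast<float>(rel));
        }
    }
}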
@@ -364,20 +391,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                // Idk why it's get<1> and not get<0> of the stride.
                // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
                // I can't get the stride from idx_row
-               flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+               flash::apply_mask_local</*HasWSLeft=*/Is_local>(
+                   scores, n_block * kBlockN, binfo.actual_seqlen_k,
                    // m_block * kBlockM + get<0>(idx_row(0)),
                    m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                   kNWarps * 16);
-                   // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
-                   // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+                   binfo.actual_seqlen_q, kNWarps * 16,
+                   params.window_size_left, params.window_size_right
+                   // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
+                   // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
+               );
+               // if (cute::thread0()) { print_tensor(scores); }
            }

            flash::cp_async_wait<0>();
            __syncthreads();
-           if (n_block > 0) {
+           if (n_block > n_block_min) {
                // Advance gK
                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-               flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+               flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
                // This cp_async_fence needs to be in the if block, otherwise the synchronization
                // isn't right and we get race conditions.
                cute::cp_async_fence();
@@ -385,24 +416,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

            // TODO: when we have key_padding_mask we'll need to Check_inf
            masking_step == 0
-               ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-               : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+               ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+               : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

            // Convert scores from fp32 to fp16/bf16
            Tensor rP = flash::convert_type<Element>(scores);
            // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
            // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-           uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-           uint32_t block_col_idx = n_block * (kBlockN / 32);
+           int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+           int block_col_idx = n_block * (kBlockN / 32);
            if (Return_softmax) {
                Tensor tOrP_copy = make_fragment_like(tOrP);
-               copy(tOrP, tOrP_copy);
+               cute::copy(tOrP, tOrP_copy);
                flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                    tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                    block_row_idx, block_col_idx, kNWarps
                );
-               flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+               flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
                tPgP.data() = tPgP.data() + (-kBlockN);
            }
            if (Is_dropout) {
@@ -411,37 +442,38 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
            }
            // if (cute::thread0()) { print(tOrP); }

-           flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+           flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
            // if (cute::thread0()) { print(scores); }

            // This check is at the end of the loop since we always have at least 1 iteration
-           if (n_masking_steps > 1 && n_block <= 0) {
+           if (n_masking_steps > 1 && n_block <= n_block_min) {
                --n_block;
                break;
            }
        }

        // These are the iterations where we don't need masking on S
-       for (; n_block >= 0; --n_block) {
+       for (; n_block >= n_block_min; --n_block) {
            Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
            clear(acc_s);
            flash::cp_async_wait<0>();
            __syncthreads();
            // Advance gV
            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-           flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+           flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
            cute::cp_async_fence();

            flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
-               acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+               acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+               smem_thr_copy_Q, smem_thr_copy_K
            );

            flash::cp_async_wait<0>();
            __syncthreads();
-           if (n_block > 0) {
+           if (n_block > n_block_min) {
                // Advance gK
                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-               flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+               flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
                // This cp_async_fence needs to be in the if block, otherwise the synchronization
                // isn't right and we get race conditions.
                cute::cp_async_fence();
@@ -449,22 +481,44 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

            // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
            Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-           softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+
+           if (Has_alibi) {
+               flash::apply_alibi<Is_causal>(
+                   scores,
+                   n_block * kBlockN,
+                   binfo.actual_seqlen_k,
+                   m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                   binfo.actual_seqlen_q,
+                   kNWarps * 16,
+                   alibi_slope
+               );
+           }
+
+           if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+               flash::apply_mask_local(
+                   scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                   m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                   binfo.actual_seqlen_q, kNWarps * 16,
+                   params.window_size_left, params.window_size_right
+               );
+           }
+
+           softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

            Tensor rP = flash::convert_type<Element>(scores);
            // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
            // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-           uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-           uint32_t block_col_idx = n_block * (kBlockN / 32);
+           int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+           int block_col_idx = n_block * (kBlockN / 32);
            if (Return_softmax) {
                Tensor tOrP_copy = make_fragment_like(tOrP);
-               copy(tOrP, tOrP_copy);
+               cute::copy(tOrP, tOrP_copy);
                flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                    tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                    block_row_idx, block_col_idx, kNWarps
                );
-               flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+               flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
                tPgP.data() = tPgP.data() + (-kBlockN);
            }
            if (Is_dropout) {
@@ -472,7 +526,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                                     block_row_idx, block_col_idx, kNWarps);
            }

-           flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+           flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
        }

        // Epilogue
@@ -496,15 +550,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        Tensor rO = flash::convert_type<Element>(acc_o);
        Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
        // Partition sO to match the accumulator partitioning
-       auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
-       // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
+       auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+       auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
        Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
        Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

        // sO has the same size as sQ, so we don't need to sync here.
        if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }

-       copy(smem_thr_copy_O, taccOrO, taccOsO);
+       cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
@@ -515,14 +569,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                                  Shape<Int<kBlockM>>{}, Stride<_1>{});

-       auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
+       typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+       auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
        Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

        __syncthreads();

        Tensor tOrO = make_tensor<Element>(shape(tOgO));
-       copy(gmem_thr_copy_O, tOsO, tOrO);
+       cute::copy(gmem_tiled_copy_O, tOsO, tOrO);

        Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor taccOcO = thr_mma.partition_C(caccO);                                  // (MMA,MMA_M,MMA_K)
@@ -548,14 +603,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
        }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
-       flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-           gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+       flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+           gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
        );
    }


    ////////////////////////////////////////////////////////////////////////////////////////////////////

-   template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+   template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
    inline __device__ void compute_attn(const Params &params) {
        const int m_block = blockIdx.x;
        // The block index for the batch.
@@ -571,7 +627,7 @@ inline __device__ void compute_attn(const Params &params) {
        // the attention matrix. This way, as long as we have the batch, head, and the location of
        // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

-       flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+       flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -4,15 +4,14 @@

#pragma once

-// #include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
-   flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
+   static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+   flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -26,35 +25,39 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {

    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.b, params.h);
-   // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
-   // for cu_seqlens_q as well.
-   const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
+   const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
-   BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
+   BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+           BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
+                   BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        // Will only return softmax if dropout, to reduce compilation time.
-                       auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
-                       // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
-                       // if (smem_size >= 48 * 1024) {
-                       //     C10_CUDA_CHECK(cudaFuncSetAttribute(
-                       //         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                       // }
-                       int ctas_per_sm;
-                       cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                           &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                       // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                       // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+                       // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                       // If Is_local, set Is_causal to false
+                       auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                       // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
+                       // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+                       // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+                       // int ctas_per_sm;
+                       // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                       //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                       // C10_CUDA_KERNEL_LAUNCH_CHECK();
+                   });
+               });
            });
        });
    });
}


template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
-   constexpr int Headdim = 32;
+   constexpr static int Headdim = 32;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
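The nested BOOL_SWITCH calls above lift runtime flags (even sequence lengths, local attention, ALiBi, etc.) into compile-time template parameters so that each combination gets its own fully specialized kernel. A minimal stand-alone sketch of that pattern, assuming nothing beyond the standard library (the real macro lives in static_switch.h and is not shown here):

// Illustration of the BOOL_SWITCH idea: turn a runtime bool into a template parameter.
#include <type_traits>
#include <utility>

template <typename F>
void bool_switch(bool cond, F &&f) {
    if (cond) {
        std::forward<F>(f)(std::integral_constant<bool, true>{});
    } else {
        std::forward<F>(f)(std::integral_constant<bool, false>{});
    }
}

// Usage sketch: two runtime flags become the template arguments of a hypothetical launch<A, B>():
//   bool_switch(is_even_mn, [&](auto even) {
//       bool_switch(has_alibi, [&](auto alibi) {
//           launch<decltype(even)::value, decltype(alibi)::value>(params);
//       });
//   });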
@@ -64,7 +67,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {

template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
-   constexpr int Headdim = 64;
+   constexpr static int Headdim = 64;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr(!Is_dropout) {
@@ -86,7 +89,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {

template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
-   constexpr int Headdim = 96;
+   constexpr static int Headdim = 96;
    // auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -112,7 +115,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {

template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
-   constexpr int Headdim = 128;
+   constexpr static int Headdim = 128;
    // auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -149,7 +152,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {

template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-   constexpr int Headdim = 160;
+   constexpr static int Headdim = 160;
    // auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -179,7 +182,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {

template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
-   constexpr int Headdim = 192;
+   constexpr static int Headdim = 192;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr(!Is_dropout) {
@@ -198,7 +201,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {

template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-   constexpr int Headdim = 224;
+   constexpr static int Headdim = 224;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
@@ -224,7 +227,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {

template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
-   constexpr int Headdim = 256;
+   constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_sm, max_smem_per_block;
@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
        SmemLayoutAtomQ{},
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));

-   using SmemLayoutAtomVtransposed = decltype(
-       composition(Swizzle<kSwizzle, 3, 3>{},
-                   // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
-                   Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
-                          Stride<_1, Int<kBlockKSmem>>>{}));
+   // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
+   using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+                                                     Stride<_1, Int<kBlockKSmem>>>;
+   using SmemLayoutAtomVtransposed = decltype(
+       composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
    using SmemLayoutVtransposed = decltype(tile_to_shape(
        SmemLayoutAtomVtransposed{},
        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
    // Maybe the VtransposeNoSwizzle just needs to have the right shape
    // And the strides don't matter?
-   using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+   using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
+       SmemLayoutAtomVtransposedNoSwizzle{},
+       Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+   // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());

    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
@@ -110,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base {
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
-   using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
+   using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
+   using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;

    static constexpr int kSmemQCount = size(SmemLayoutQ{});
    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
@@ -138,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base {
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
-       make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+       make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
-       make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+       make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
@@ -151,10 +155,30 @@ struct Flash_fwd_kernel_traits : public Base {
                                   Stride<Int<kGmemThreadsPerRowP>, _1>>;

    using GmemTiledCopyP = decltype(
-       make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+       make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtomP{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

+   using GmemLayoutAtomOaccum = std::conditional_t<
+       kBlockKSmem == 32,
+       Layout<Shape <_16, _8>,  // Thread layout, 8 threads per row
+              Stride< _8, _1>>,
+       Layout<Shape <_8, _16>,  // Thread layout, 16 threads per row
+              Stride< _16, _1>>
+   >;
+   using GmemTiledCopyOaccum = decltype(
+       make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                       GmemLayoutAtomOaccum{},
+                       Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
+   using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+   using GmemTiledCopyRotcossin = decltype(
+       make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                       GmemLayoutAtomRotcossin{},
+                       Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per load
+   using GmemTiledCopyRotcossinCont = decltype(
+       make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                       GmemLayoutAtomRotcossin{},
+                       Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load
};

// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
@@ -223,16 +247,19 @@ struct Flash_bwd_kernel_traits : public Base {
        SmemLayoutAtomKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));

+   using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+                                                     Stride<_1, Int<kBlockKSmem>>>;
    using SmemLayoutAtomKtransposed = decltype(
-       composition(Swizzle<kSwizzle, 3, 3>{},
-                   Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
-                          Stride<_1, Int<kBlockKSmem>>>{}));
+       composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
    using SmemLayoutKtransposed = decltype(tile_to_shape(
        SmemLayoutAtomKtransposed{},
        make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
    // Maybe the KtransposeNoSwizzle just needs to have the right shape
    // And the strides don't matter?
-   using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
+   using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
+       SmemLayoutAtomKtransposedNoSwizzle{},
+       make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
+   // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());

    // TODO: generalize to other values of kBlockN
    // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
@@ -250,24 +277,30 @@ struct Flash_bwd_kernel_traits : public Base {
    using SmemLayoutPdS = decltype(tile_to_shape(
        SmemLayoutAtomPdS{},
        make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
+   using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
+                                                       Stride<_1, Int<kPBlockN>>>;
    using SmemLayoutAtomPdStransposed = decltype(
-       composition(Swizzle<kSwizzlePdS, 3, 3>{},
-                   Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
-                          Stride<_1, Int<kPBlockN>>>{}));
+       composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
    using SmemLayoutPdStransposed = decltype(tile_to_shape(
        SmemLayoutAtomPdStransposed{},
        make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
-   using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
+   using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
+       SmemLayoutAtomPdStransposedNoSwizzle{},
+       make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
+   // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
    using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;

+   using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
+                                                       Stride<_1, Int<kBlockKSmem>>>;
    using SmemLayoutAtomQdOtransposed = decltype(
-       composition(Swizzle<kSwizzle, 3, 3>{},
-                   Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
-                          Stride<_1, Int<kBlockKSmem>>>{}));
+       composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
    using SmemLayoutQdOtransposed = decltype(tile_to_shape(
        SmemLayoutAtomQdOtransposed{},
        make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
-   using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
+   using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
+       SmemLayoutAtomQdOtransposedNoSwizzle{},
+       make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
+   // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());

    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
@@ -292,13 +325,11 @@ struct Flash_bwd_kernel_traits : public Base {
    static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
    static constexpr int kSmemPCount = size(SmemLayoutPdS{});
    static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
-   static constexpr int kSmemdPsumCount = kBlockM;
    static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
    static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
    static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
    static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
-   static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
    static constexpr int kSmemSize = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
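As a worked example of the kSmem*Size accounting in these kernel traits (a sketch only, assuming fp16/bf16 so sizeof(Element) == 2, and the forward tile kBlockM = kBlockN = 128, kHeadDim = 64 used by several of the launch functions above):

// Illustrative shared-memory budget for one forward tile (hypothetical tile sizes).
constexpr int kBlockM = 128, kBlockN = 128, kHeadDim = 64;
constexpr int kElemBytes = 2;                                      // fp16 / bf16
constexpr int kSmemQSize  = kBlockM * kHeadDim * kElemBytes;       // 16 KiB for Q
constexpr int kSmemKVSize = 2 * kBlockN * kHeadDim * kElemBytes;   // 32 KiB for K and V
constexpr int kSmemSize   = kSmemQSize + kSmemKVSize;              // 48 KiB without Share_Q_K_smem
static_assert(kSmemSize == 48 * 1024);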
159  candle-flash-attn/kernels/kernel_traits_sm90.h  Normal file
@@ -0,0 +1,159 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include "cute/algorithm/copy.hpp"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/layout/layout.h"
+#include <cutlass/numeric_types.h>
+
+using namespace cute;
+
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
+struct Flash_kernel_traits_sm90 {
+
+#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800
+    using Element = elem_type;
+    static constexpr bool Has_cp_async = true;
+#else
+    using Element = cutlass::half_t;
+    static constexpr bool Has_cp_async = false;
+#endif
+
+    using ElementAccum = float;
+    using index_t = uint32_t;
+
+#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800
+    using MMA_Atom_Arch = std::conditional_t<
+        std::is_same_v<elem_type, cutlass::half_t>,
+        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
+    >;
+    using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
+#else
+    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
+    using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
+#endif
+
+#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 750
+    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
+    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
+#else
+    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
+#endif
+};
+
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
+         typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
+struct Flash_fwd_kernel_traits : public Base {
+    using Element = typename Base::Element;
+    using ElementAccum = typename Base::ElementAccum;
+    using index_t = typename Base::index_t;
+    static constexpr bool Has_cp_async = Base::Has_cp_async;
+    using SmemCopyAtom = typename Base::SmemCopyAtom;
+    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
+
+    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
+    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
+
+    // The number of threads.
+    static constexpr int kNWarps = kNWarps_;
+    static constexpr int kNThreads = kNWarps * 32;
+
+    static constexpr int kBlockM = kBlockM_;
+    static constexpr int kBlockN = kBlockN_;
+    static constexpr int kHeadDim = kHeadDim_;
+    static_assert(kHeadDim % 32 == 0);
+    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
+    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
+    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
+
+    using TiledMma = TiledMMA<
+        typename Base::MMA_Atom_Arch,
+        Layout<Shape<Int<kNWarps>,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
+        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+    using SmemLayoutAtomQ = decltype(
+        composition(Swizzle<kSwizzle, 3, 3>{},
+                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
+                    Layout<Shape<_8, Int<kBlockKSmem>>,
+                           Stride<Int<kBlockKSmem>, _1>>{}));
+    using SmemLayoutQ = decltype(tile_to_shape(
+        SmemLayoutAtomQ{},
+        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+
+    using SmemLayoutKV = decltype(tile_to_shape(
+        SmemLayoutAtomQ{},
+        Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+
+    using SmemLayoutAtomVtransposed = decltype(
+        composition(Swizzle<kSwizzle, 3, 3>{},
+                    // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
+                    Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+                           Stride<_1, Int<kBlockKSmem>>>{}));
+    using SmemLayoutVtransposed = decltype(tile_to_shape(
+        SmemLayoutAtomVtransposed{},
+        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+    // Maybe the VtransposeNoSwizzle just needs to have the right shape
+    // And the strides don't matter?
+    using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+
+    using SmemLayoutAtomO = decltype(
+        composition(Swizzle<kSwizzle, 3, 3>{},
+                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,
+                           Stride<Int<kBlockKSmem>, _1>>{}));
+    using SmemLayoutO = decltype(tile_to_shape(
+        SmemLayoutAtomO{},
+        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+    using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
+
+    static constexpr int kSmemQCount = size(SmemLayoutQ{});
+    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
+    static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
+    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
+
+    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
+    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
+    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
+    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
+    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
+    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
+    // to the same banks.
+    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
+    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
+    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
+                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
+
+    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
+    // from the same address by the same threadblock. This is slightly faster.
+    using Gmem_copy_struct = std::conditional_t<
+        Has_cp_async,
+        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
+        DefaultCopy
+    >;
+    using GmemTiledCopyQKV = decltype(
+        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+                        GmemLayoutAtom{},
+                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+    using GmemTiledCopyO = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+                        GmemLayoutAtom{},
+                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+    static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
+    static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
+    using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
+                                   Stride<Int<kGmemThreadsPerRowP>, _1>>;
+
+    using GmemTiledCopyP = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+                        GmemLayoutAtomP{},
+                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -8,8 +8,7 @@

#include <cute/tensor.hpp>

-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>

#include "philox.cuh"
#include "utils.h"
@@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
}

template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+                                  const int col_idx_offset_ = 0) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
-   const uint32_t lane_id = threadIdx.x % 32;
+   const int lane_id = threadIdx.x % 32;
+   const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+       const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
-           const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
+           const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Engine, typename Layout>
|
template <bool HasWSLeft=true, typename Engine, typename Layout>
|
||||||
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
|
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||||
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
|
const int max_seqlen_k, const int row_idx_offset,
|
||||||
const uint32_t warp_row_stride) {
|
const int max_seqlen_q, const int warp_row_stride,
|
||||||
|
const int window_size_left, const int window_size_right) {
|
||||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||||
const uint32_t lane_id = threadIdx.x % 32;
|
const int lane_id = threadIdx.x % 32;
|
||||||
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
|
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||||
const uint32_t row_idx_offset = row_idx_offset_;
|
|
||||||
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||||
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
|
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||||
const uint32_t row_idx = row_idx_base + i * 8;
|
const int row_idx = row_idx_base + i * 8;
|
||||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
|
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
|
||||||
|
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||||
const uint32_t col_idx_base = col_idx_offset + nj * 8;
|
const int col_idx_base = col_idx_offset + nj * 8;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||||
const uint32_t col_idx = col_idx_base + j;
|
const int col_idx = col_idx_base + j;
|
||||||
if (col_idx >= col_idx_limit) {
|
if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
|
||||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
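The `apply_mask_local` kernel above encodes the sliding-window rule that also covers causal attention. Below is a minimal host-side sketch of the same predicate in Rust, for illustration only; the helper name `is_masked` and the test values in `main` are hypothetical and not part of the kernels.

// Hypothetical reference of the masking rule used by apply_mask_local: a position
// (row_idx, col_idx) is masked out when it falls outside the local attention window.
fn is_masked(
    row_idx: i64,
    col_idx: i64,
    max_seqlen_q: i64,
    max_seqlen_k: i64,
    window_size_left: i64, // negative means "unbounded on the left"
    window_size_right: i64,
) -> bool {
    // Shift that aligns the last query row with the last key column.
    let shift = max_seqlen_k - max_seqlen_q;
    let limit_right = max_seqlen_k.min(row_idx + 1 + shift + window_size_right);
    if col_idx >= limit_right {
        return true;
    }
    if window_size_left >= 0 {
        let limit_left = 0.max(row_idx + shift - window_size_left);
        if col_idx < limit_left {
            return true;
        }
    }
    false
}

fn main() {
    // Causal masking is local masking with an unbounded left window and window_size_right = 0.
    assert!(is_masked(0, 1, 4, 4, -1, 0));
    assert!(!is_masked(3, 3, 4, 4, -1, 0));
    // A (1, 0) sliding window only keeps the current and the previous key position.
    assert!(is_masked(3, 1, 4, 4, 1, 0));
    assert!(!is_masked(3, 2, 4, 4, 1, 0));
}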
@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
     }
 }
 
+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset,
+                                         const int max_seqlen_q, const int warp_row_stride) {
+    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
+                                          max_seqlen_q, warp_row_stride, -1, 0);
+}
 
 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
     CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
     #pragma unroll
     for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx(
 template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
 inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
                                      unsigned long long seed, unsigned long long offset,
-                                     uint32_t block_row_start, uint32_t block_col_start,
+                                     int block_row_start, int block_col_start,
-                                     uint32_t block_row_stride) {
+                                     int block_row_stride) {
     // tensor has shape (8, MMA_M, MMA_N / 2)
     using T = typename Engine::value_type;
     auto encode_dropout = [](bool keep, T val) {
@ -87,46 +87,6 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename T>
-inline __device__ float2 half2_unpack(uint32_t a);
-
-template <>
-inline __device__ float2 half2_unpack<__half>(uint32_t a) {
-    return __half22float2(reinterpret_cast<__half2 (&)>(a));
-}
-
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-template <>
-inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
-    return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
-}
-#endif
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Convert two half2's or bf162's into float, then take their dot product.
-template <typename T>
-inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
-    float2 af = flash::half2_unpack<T>(a);
-    float2 bf = flash::half2_unpack<T>(b);
-    return af.x * bf.x + af.y * bf.y;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
-template<typename T>
-inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
-    float sum;
-    sum = flash::hfma2_to_float<T>(a.x, b.x);
-    sum += flash::hfma2_to_float<T>(a.y, b.y);
-    sum += flash::hfma2_to_float<T>(a.z, b.z);
-    sum += flash::hfma2_to_float<T>(a.w, b.w);
-    return sum;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<typename T>
 struct MaxOp {
 __device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
@ -173,10 +133,12 @@ static __device__ inline T run(T x, Operator &op) {
 
 template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
          typename Tensor2, typename Tensor3, typename Tensor4,
-         typename TiledMma, typename TiledCopy0, typename TiledCopy1>
+         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+         typename ThrCopyA, typename ThrCopyB>
 inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                             Tensor4 const& tCsB, TiledMma tiled_mma,
-                            TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
+                            TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+                            ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                     // MMA_K
@ -184,13 +146,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
     CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));            // M
     Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
-    if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
+    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
-    if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
+    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
-            if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
+            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
-            if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
+            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
@ -199,19 +161,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
-         typename TiledMma, typename TiledCopy>
+         typename TiledMma, typename TiledCopy, typename ThrCopy>
 inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
-                                      TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
+                                      TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                                      ThrCopy smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                     // MMA_K
     Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
-    copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
-            copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
@ -225,7 +188,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
     static_assert(decltype(size<0>(acc_layout))::value == 4);
     static_assert(decltype(rank(acc_layout))::value == 3);
     auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
-    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+    // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
+    // "int_tuple.hpp(74): error: conversion to inaccessible base class"
+    // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+    return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -241,9 +207,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
     static_assert(mma_shape_K == 8 || mma_shape_K == 16);
     constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
     auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
-    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
+    // TD [2023-08-13]: Same error as above on Cutlass 3.2
-                       get<0, 1>(l),
+    // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
-                       get<1, 1, 1>(l));
+    //                    get<0, 1>(l),
+    //                    get<1, 1, 1>(l));
+    return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
+                       get<1>(get<0>(l)),
+                       get<1>(get<1>(get<1>(l))));
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -319,9 +289,9 @@ void cp_async_wait() {
 template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
           typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
+inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                             Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                            Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
+                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
@ -335,13 +305,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
-                    copy(thr_copy, S(_, m, k), D(_, m, k));
+                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
-                    clear(D(_, m, k));
+                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
-            clear(D(_, m, _));
+            cute::clear(D(_, m, _));
        }
    }
    // TD [2023-04-13]: Strange that the code below can cause race condition.
@ -350,7 +320,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
    // #pragma unroll
    // for (int m = 0; m < size<1>(S); ++m) {
    //     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-    //         copy(thr_copy, S(_, m, _), D(_, m, _));
+    //         copy(tiled_copy, S(_, m, _), D(_, m, _));
    //     } else if (Clear_OOB_MN) {
    //         clear(D(_, m, _));
    //     }
@ -362,7 +332,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
    //     #pragma unroll
    //     for (int m = 0; m < size<1>(S); ++m) {
    //         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-    //             copy(thr_copy, S(_, m, k), D(_, m, k));
+    //             copy(tiled_copy, S(_, m, k), D(_, m, k));
    //         } else if (Clear_OOB_MN) {
    //             clear(D(_, m, k));
    //         }
@ -7,6 +7,8 @@ extern "C" {
         v_ptr: *const c_void,
         o_ptr: *const c_void,
         softmax_lse_ptr: *const c_void,
+        alibi_slopes_ptr: *const c_void,
 
         cu_seqlens_q_ptr: *const i32,
         cu_seqlens_k_ptr: *const i32,
 
@ -14,6 +16,7 @@ extern "C" {
         k_batch_stride: u32,
         v_batch_stride: u32,
         o_batch_stride: u32,
+        alibi_slopes_batch_stride: u32,
 
         q_row_stride: u32,
         k_row_stride: u32,
@ -37,8 +40,11 @@ extern "C" {
         seqlen_q_rounded: u32,
         seqlen_k_rounded: u32,
 
-        is_causal: c_int,
         is_bf16: c_int,
+        is_causal: c_int,
 
+        window_size_left: c_int,
+        window_size_right: c_int,
     );
 
 }
@ -3,12 +3,14 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
 use half::{bf16, f16};
 
 pub struct FlashAttn {
     pub softmax_scale: f32,
-    pub causal: bool,
+    pub alibi_slopes: Option<Tensor>,
+    pub window_size_left: Option<usize>,
+    pub window_size_right: Option<usize>,
 }
 
 fn round_multiple(x: usize, m: usize) -> usize {
@ -85,6 +87,51 @@ impl FlashAttn {
             candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
         }
 
+        let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
+            if alibi_slopes.dtype() != DType::F32 {
+                candle::bail!(
+                    "DType mismatch alibi_slopes {:?}, expected {:?}",
+                    alibi_slopes.dtype(),
+                    DType::F32
+                );
+            }
+
+            let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
+
+            if num_heads != alibi_slopes_layout.shape().dims1()? {
+                candle::bail!(
+                    "shape mismatch alibi_slopes {:?}, expected {:?}",
+                    alibi_slopes_layout.shape(),
+                    (num_heads)
+                );
+            }
+
+            let alibi_slopes = match &*alibi_slopes {
+                candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
+                _ => candle::bail!("alibi_slopes must be a cuda tensor"),
+            };
+
+            let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
+
+            *alibi_slopes.device_ptr() as *const core::ffi::c_void
+        } else {
+            std::ptr::null()
+        };
+
+        // if window_size_left > self.max_seqlen_k or None => -1
+        let mut window_size_left = self
+            .window_size_left
+            .filter(|v| v <= &seqlen_k)
+            .map(|v| v as i32)
+            .unwrap_or(-1);
+
+        // if window_size_right > self.max_seqlen_k or None => -1
+        let mut window_size_right = self
+            .window_size_right
+            .filter(|v| v <= &seqlen_k)
+            .map(|v| v as i32)
+            .unwrap_or(-1);
+
         let head_size = round_multiple(head_size_og, 8);
         let head_size_rounded = round_multiple(head_size, 32);
         let seqlen_q_rounded = round_multiple(seqlen_q, 128);
@ -94,9 +141,22 @@ impl FlashAttn {
         let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
         let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
 
-        let causal = if self.causal { 1 } else { 0 };
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
+        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+        let is_causal = if window_size_left < 0 && window_size_right == 0 {
+            1
+        } else {
+            0
+        };
+        if window_size_left < 0 && window_size_right >= 0 {
+            window_size_left = seqlen_k as i32;
+        }
+        if window_size_left >= 0 && window_size_right < 0 {
+            window_size_right = seqlen_k as i32;
+        }
+
         unsafe {
             let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
             let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
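The block above maps the optional window sizes onto the C sentinels the kernel expects. The following standalone sketch (a hypothetical free function, not part of the crate) mirrors that normalization: `None` or a window larger than the key length becomes -1, causality is detected as (left < 0, right == 0), and any side that stayed unbounded is then widened to the full key length.

fn normalize_windows(
    window_size_left: Option<usize>,
    window_size_right: Option<usize>,
    seqlen_k: usize,
) -> (i32, i32, i32) {
    // None or an over-long window maps to the -1 sentinel.
    let mut left = window_size_left
        .filter(|v| v <= &seqlen_k)
        .map(|v| v as i32)
        .unwrap_or(-1);
    let mut right = window_size_right
        .filter(|v| v <= &seqlen_k)
        .map(|v| v as i32)
        .unwrap_or(-1);
    // Causal is the special case of a zero right window with an unbounded left window.
    let is_causal = if left < 0 && right == 0 { 1 } else { 0 };
    if left < 0 && right >= 0 {
        left = seqlen_k as i32;
    }
    if left >= 0 && right < 0 {
        right = seqlen_k as i32;
    }
    (is_causal, left, right)
}

fn main() {
    // The causal wrappers pass (None, Some(0)), which the kernel sees as is_causal = 1.
    assert_eq!(normalize_windows(None, Some(0), 128), (1, 128, 0));
    // Fully unwindowed attention keeps both sentinels at -1.
    assert_eq!(normalize_windows(None, None, 128), (0, -1, -1));
}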
@ -109,12 +169,14 @@ impl FlashAttn {
                 v_ptr,
                 dst_ptr,
                 softmax_lse_ptr,
+                /* alibi_slopes_ptr */ alibi_slopes_ptr,
                 /* cu_seqlens_q_ptr */ std::ptr::null(),
                 /* cu_seqlens_k_ptr */ std::ptr::null(),
                 /* q_batch_stride */ q_stride[0] as u32,
                 /* k_batch_stride */ k_stride[0] as u32,
                 /* v_batch_stride */ v_stride[0] as u32,
                 /* o_batch_stride */ o_stride[0] as u32,
+                /* alibi_slopes_batch_stride */ 0,
                 /* q_row_stride */ q_stride[q_rank - 3] as u32,
                 /* k_row_stride */ k_stride[k_rank - 3] as u32,
                 /* v_row_stride */ v_stride[v_rank - 3] as u32,
@ -133,8 +195,10 @@ impl FlashAttn {
                 /* seqlen_k */ seqlen_k as u32,
                 /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
-                /* is_causal */ causal,
                 /* is_bf16 */ is_bf16,
+                /* is_causal */ is_causal,
+                /* window_size_left */ window_size_left,
+                /* window_size_right */ window_size_right,
             )
         }
 
@ -197,20 +261,137 @@ pub fn flash_attn(
     softmax_scale: f32,
     causal: bool,
 ) -> Result<Tensor> {
+    let window_size_left = None;
+    let window_size_right = if causal { Some(0) } else { None };
+
     let op = FlashAttn {
         softmax_scale,
-        causal,
+        alibi_slopes: None,
+        window_size_left,
+        window_size_right,
+    };
+    q.apply_op3(k, v, op)
+}
+
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `window_size_left` - Limit left attention to value tokens.
+/// * `window_size_right` - Limit right attention to value tokens.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_windowed(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    window_size_left: Option<usize>,
+    window_size_right: Option<usize>,
+) -> Result<Tensor> {
+    let op = FlashAttn {
+        softmax_scale,
+        alibi_slopes: None,
+        window_size_left,
+        window_size_right,
+    };
+    q.apply_op3(k, v, op)
+}
+
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_alibi(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    alibi_slopes: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    let window_size_left = None;
+    let window_size_right = if causal { Some(0) } else { None };
+
+    let op = FlashAttn {
+        softmax_scale,
+        alibi_slopes: Some(alibi_slopes.clone()),
+        window_size_left,
+        window_size_right,
+    };
+    q.apply_op3(k, v, op)
+}
+
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
+/// * `window_size_left` - Limit left attention to value tokens.
+/// * `window_size_right` - Limit right attention to value tokens.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_alibi_windowed(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    alibi_slopes: &Tensor,
+    softmax_scale: f32,
+    window_size_left: Option<usize>,
+    window_size_right: Option<usize>,
+) -> Result<Tensor> {
+    let op = FlashAttn {
+        softmax_scale,
+        alibi_slopes: Some(alibi_slopes.clone()),
+        window_size_left,
+        window_size_right,
     };
     q.apply_op3(k, v, op)
 }
 
 struct FlashAttnVarLen {
-    softmax_scale: f32,
-    causal: bool,
-    max_seqlen_q: usize,
-    max_seqlen_k: usize,
-    seqlens_q: Tensor,
-    seqlens_k: Tensor,
+    pub softmax_scale: f32,
+    pub max_seqlen_q: usize,
+    pub max_seqlen_k: usize,
+    pub seqlens_q: Tensor,
+    pub seqlens_k: Tensor,
+    pub alibi_slopes: Option<Tensor>,
+    pub window_size_left: Option<usize>,
+    pub window_size_right: Option<usize>,
 }
 
 impl FlashAttnVarLen {
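The wrappers above are the public entry points for the fixed-shape path. Below is a hedged usage sketch: it assumes a CUDA device and candle's `Tensor::randn`/`to_dtype` constructors, and the tensor sizes and the `demo` function name are purely illustrative.

use candle::{DType, Device, Result, Tensor};

fn demo() -> Result<Tensor> {
    let device = Device::new_cuda(0)?;
    let (b, seq, heads, head_dim) = (1, 128, 8, 64);
    // Flash attention expects f16 or bf16 CUDA tensors laid out as (batch, seq, heads, head_size).
    let q = Tensor::randn(0f32, 1.0, (b, seq, heads, head_dim), &device)?.to_dtype(DType::F16)?;
    let k = q.clone();
    let v = q.clone();
    let scale = 1.0 / (head_dim as f32).sqrt();
    // Sliding-window attention: each query sees at most 64 keys to its left and none to its
    // right, so the result is also causal.
    candle_flash_attn::flash_attn_windowed(&q, &k, &v, scale, Some(64), Some(0))
}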
@ -311,7 +492,54 @@ impl FlashAttnVarLen {
         if nseqlens_k != nseqlens_q {
             candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
         }
 
         let batch_size = nseqlens_q - 1;
 
+        let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
+            if alibi_slopes.dtype() != DType::F32 {
+                candle::bail!(
+                    "DType mismatch alibi_slopes {:?}, expected {:?}",
+                    alibi_slopes.dtype(),
+                    DType::F32
+                );
+            }
+
+            let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
+
+            if num_heads != alibi_slopes_layout.shape().dims1()? {
+                candle::bail!(
+                    "shape mismatch alibi_slopes {:?}, expected {:?}",
+                    alibi_slopes_layout.shape(),
+                    (num_heads)
+                );
+            }
+
+            let alibi_slopes = match &*alibi_slopes {
+                candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
+                _ => candle::bail!("alibi_slopes must be a cuda tensor"),
+            };
+
+            let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
+
+            *alibi_slopes.device_ptr() as *const core::ffi::c_void
+        } else {
+            std::ptr::null()
+        };
+
+        // if window_size_left > self.max_seqlen_k or None => -1
+        let mut window_size_left = self
+            .window_size_left
+            .filter(|v| v <= &self.max_seqlen_k)
+            .map(|v| v as i32)
+            .unwrap_or(-1);
+
+        // if window_size_right > self.max_seqlen_k or None => -1
+        let mut window_size_right = self
+            .window_size_right
+            .filter(|v| v <= &self.max_seqlen_k)
+            .map(|v| v as i32)
+            .unwrap_or(-1);
+
         let head_size = round_multiple(head_size_og, 8);
         let head_size_rounded = round_multiple(head_size, 32);
         let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
@ -323,9 +551,22 @@ impl FlashAttnVarLen {
             .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
             .w()?;
 
-        let causal = if self.causal { 1 } else { 0 };
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
+        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+        let is_causal = if window_size_left < 0 && window_size_right == 0 {
+            1
+        } else {
+            0
+        };
+        if window_size_left < 0 && window_size_right >= 0 {
+            window_size_left = self.max_seqlen_k as i32;
+        }
+        if window_size_left >= 0 && window_size_right < 0 {
+            window_size_right = self.max_seqlen_k as i32;
+        }
+
         unsafe {
             let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
             let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@ -340,12 +581,14 @@ impl FlashAttnVarLen {
                 v_ptr,
                 dst_ptr,
                 softmax_lse_ptr,
+                /* alibi_slopes_ptr */ alibi_slopes_ptr,
                 /* cu_seqlens_q_ptr */ seqlens_q_ptr,
                 /* cu_seqlens_k_ptr */ seqlens_k_ptr,
                 /* q_batch_stride */ 0,
                 /* k_batch_stride */ 0,
                 /* v_batch_stride */ 0,
                 /* o_batch_stride */ 0,
+                /* alibi_slopes_batch_stride */ 0,
                 /* q_row_stride */ q_stride[q_rank - 3] as u32,
                 /* k_row_stride */ k_stride[k_rank - 3] as u32,
                 /* v_row_stride */ v_stride[v_rank - 3] as u32,
@ -364,8 +607,10 @@ impl FlashAttnVarLen {
                 /* seqlen_k */ self.max_seqlen_k as u32,
                 /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
-                /* is_causal */ causal,
                 /* is_bf16 */ is_bf16,
+                /* is_causal */ is_causal,
+                /* window_size_left */ window_size_left,
+                /* window_size_right */ window_size_right,
             )
         }
 
@ -440,13 +685,176 @@ pub fn flash_attn_varlen(
     softmax_scale: f32,
     causal: bool,
 ) -> Result<Tensor> {
+    let window_size_left = None;
+    let window_size_right = if causal { Some(0) } else { None };
+
     let op = FlashAttnVarLen {
         softmax_scale,
-        causal,
         max_seqlen_q,
         max_seqlen_k,
         seqlens_q: seqlens_q.clone(),
         seqlens_k: seqlens_k.clone(),
+        alibi_slopes: None,
+        window_size_left,
+        window_size_right,
+    };
+    q.apply_op3(k, v, op)
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
+/// * `window_size_left` - Limit left attention to value tokens.
+/// * `window_size_right` - Limit right attention to value tokens.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`
+pub fn flash_attn_varlen_windowed(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    seqlens_q: &Tensor,
+    seqlens_k: &Tensor,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+    softmax_scale: f32,
+    window_size_left: Option<usize>,
+    window_size_right: Option<usize>,
+) -> Result<Tensor> {
+    let op = FlashAttnVarLen {
+        softmax_scale,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqlens_q: seqlens_q.clone(),
+        seqlens_k: seqlens_k.clone(),
+        alibi_slopes: None,
+        window_size_left,
+        window_size_right,
+    };
+    q.apply_op3(k, v, op)
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+pub fn flash_attn_varlen_alibi(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    alibi_slopes: &Tensor,
+    seqlens_q: &Tensor,
+    seqlens_k: &Tensor,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    let window_size_left = None;
+    let window_size_right = if causal { Some(0) } else { None };
+
+    let op = FlashAttnVarLen {
+        softmax_scale,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqlens_q: seqlens_q.clone(),
+        seqlens_k: seqlens_k.clone(),
+        alibi_slopes: Some(alibi_slopes.clone()),
+        window_size_left,
+        window_size_right,
+    };
+    q.apply_op3(k, v, op)
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
+/// * `window_size_left` - Limit left attention to value tokens.
+/// * `window_size_right` - Limit right attention to value tokens.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`
+pub fn flash_attn_varlen_alibi_windowed(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    alibi_slopes: &Tensor,
+    seqlens_q: &Tensor,
+    seqlens_k: &Tensor,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+    softmax_scale: f32,
+    window_size_left: Option<usize>,
+    window_size_right: Option<usize>,
+) -> Result<Tensor> {
+    let op = FlashAttnVarLen {
+        softmax_scale,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqlens_q: seqlens_q.clone(),
+        seqlens_k: seqlens_k.clone(),
+        alibi_slopes: Some(alibi_slopes.clone()),
+        window_size_left,
+        window_size_right,
     };
     q.apply_op3(k, v, op)
 }
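The varlen entry points above take packed tensors plus cumulative sequence lengths. The sketch below is hedged: it assumes the offset tensors are u32 and live on the same CUDA device as q/k/v, and the `demo_varlen` name, the two sequence lengths (5 and 9) and the head sizes are illustrative only.

use candle::{DType, Device, Result, Tensor};

fn demo_varlen() -> Result<Tensor> {
    let device = Device::new_cuda(0)?;
    let (total, heads, head_dim) = (14, 8, 64); // 5 + 9 tokens packed into one batch
    let q = Tensor::randn(0f32, 1.0, (total, heads, head_dim), &device)?.to_dtype(DType::F16)?;
    let k = q.clone();
    let v = q.clone();
    // batch_size + 1 cumulative lengths: 0, 5, 5 + 9.
    let seqlens = Tensor::new(&[0u32, 5, 14], &device)?;
    let scale = 1.0 / (head_dim as f32).sqrt();
    // window_size_left = None, window_size_right = Some(0) gives the causal mask.
    candle_flash_attn::flash_attn_varlen_windowed(
        &q, &k, &v, &seqlens, &seqlens, 9, 9, scale, None, Some(0),
    )
}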
@ -56,15 +56,24 @@ kernel void FN_NAME_STRIDED( \
 
 #define BINARY_OP(FN, NAME) \
 BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
-BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
+BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
+BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
+BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
+
+#define INT64_BINARY_OP(NAME, FN) \
+BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
 
 #define BFLOAT_BINARY_OP(FN, NAME) \
 BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
 
 #define BINARY_OP_OUT(NAME, FN) \
 BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
-BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
+BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
+BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
+BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
+
+#define INT64_BINARY_OP_OUT(NAME, FN) \
+BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
 
 BINARY_OP(x + y, add)
 BINARY_OP(x - y, sub)
@ -80,6 +89,22 @@ BINARY_OP_OUT(lt, x < y)
 BINARY_OP_OUT(ge, x >= y)
 BINARY_OP_OUT(gt, x > y)
 
+#if __METAL_VERSION__ >= 220
+INT64_BINARY_OP(add, x + y)
+INT64_BINARY_OP(sub, x - y)
+INT64_BINARY_OP(mul, x * y)
+INT64_BINARY_OP(div, x / y)
+INT64_BINARY_OP(min, MIN(x, y))
+INT64_BINARY_OP(max, MAX(x, y))
+
+INT64_BINARY_OP_OUT(eq, x == y)
+INT64_BINARY_OP_OUT(ne, x != y)
+INT64_BINARY_OP_OUT(le, x <= y)
+INT64_BINARY_OP_OUT(lt, x < y)
+INT64_BINARY_OP_OUT(ge, x >= y)
+INT64_BINARY_OP_OUT(gt, x > y)
+#endif
+
 #if __METAL_VERSION__ >= 310
 BFLOAT_BINARY_OP(x + y, add)
 BFLOAT_BINARY_OP(x - y, sub)
@ -52,5 +52,13 @@ CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
 CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
 CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
 
-#if __METAL_VERSION__ >= 310
+#if __METAL_VERSION__ >= 220
+CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
+CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
+CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
+#endif
+
+#if __METAL_VERSION__ >= 310
+CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
+CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
 #endif
@ -137,6 +137,9 @@ macro_rules! ops{
                 pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
                 pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
                 pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
+                pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
+                pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
+                pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
             }
         )+
         pub mod copy {
@ -144,6 +147,7 @@ macro_rules! ops{
             pub const FLOAT: Kernel = Kernel("copy_f32");
             pub const HALF: Kernel = Kernel("copy_f16");
             pub const BFLOAT: Kernel = Kernel("copy_bf16");
+            pub const I64: Kernel = Kernel("copy_i64");
             pub const U32: Kernel = Kernel("copy_u32");
             pub const U8: Kernel = Kernel("copy_u8");
         }
@ -157,6 +161,9 @@ macro_rules! ops{
                 pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
                 pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
                 pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
+                pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
+                pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
+                pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
             }
         )+
         pub mod copy {
@ -164,6 +171,7 @@ macro_rules! ops{
             pub const FLOAT: Kernel = Kernel("copy_f32_strided");
             pub const HALF: Kernel = Kernel("copy_f16_strided");
             pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
+            pub const I64: Kernel = Kernel("copy_i64_strided");
             pub const U32: Kernel = Kernel("copy_u32_strided");
             pub const U8: Kernel = Kernel("copy_u8_strided");
         }
@ -172,7 +180,10 @@ macro_rules! ops{
 }
 
 pub mod unary {
-    ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
+    ops!(
+        cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
+        recip
+    );
 }
 pub mod binary {
     ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
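The new `I64`, `U32` and `U8` constants above derive the Metal kernel names from the op identifier plus a dtype suffix. The standalone sketch below (not the crate's actual `ops!` macro) reproduces just that naming step, so the constants line up with kernels such as `add_i64` and `add_i64_strided` defined in the shaders.

// A toy macro that mimics the name construction used by ops!.
macro_rules! kernel_name {
    ($name:ident, $suffix:literal) => {
        concat!(stringify!($name), $suffix)
    };
}

fn main() {
    assert_eq!(kernel_name!(add, "_i64"), "add_i64");
    assert_eq!(kernel_name!(add, "_i64_strided"), "add_i64_strided");
    assert_eq!(kernel_name!(recip, "_f32"), "recip_f32");
}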
@ -263,24 +263,38 @@ kernel void NAME(
|
|||||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||||
|
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
|
||||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||||
|
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
|
||||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||||
|
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
|
||||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||||
|
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
|
||||||
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||||
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||||
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||||
|
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
|
||||||
|
|
||||||
SOFTMAX(softmax_f32, float)
|
SOFTMAX(softmax_f32, float)
|
||||||
SOFTMAX(softmax_f16, half)
|
SOFTMAX(softmax_f16, half)
|
||||||
|
|
||||||
|
#if __METAL_VERSION__ >= 220
|
||||||
|
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
||||||
|
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
|
||||||
|
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
|
||||||
|
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||||
|
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||||
|
#endif
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||||
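The last argument to `REDUCE`/`ARGMIN`/`ARGMAX` appears to be the initial accumulator value each thread folds from, so the new i64 and u8 instantiations pick the identity (or worst possible value) for their operation and dtype: `0` for sums and unsigned max, `INT_MAX`/`0xFF` for min, `INT_MIN` for max. A small Rust analogue of the same idea, assuming nothing beyond the standard library:

```rust
// Folding from the operation's identity/worst value yields the correct result,
// mirroring the init arguments passed to the reduce kernels above.
fn main() {
    let xs: [i64; 4] = [3, -7, 2, 5];
    let max = xs.iter().copied().fold(i64::MIN, i64::max); // init = INT_MIN
    let min = xs.iter().copied().fold(i64::MAX, i64::min); // init = INT_MAX
    let sum = xs.iter().copied().fold(0i64, |a, b| a + b); // init = 0
    assert_eq!((max, min, sum), (5, -7, 3));
}
```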
@ -55,6 +55,9 @@ kernel void FN_NAME( \

WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
-// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
+WHERE_OP(uint8_t, uint8_t, where_u8_u8)
-// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
+WHERE_OP(uint32_t, uint8_t, where_u8_u32)
-// WHERE_OP(int64_t, uint8_t, where_u8_i64)
+
+#if __METAL_VERSION__ >= 220
+WHERE_OP(int64_t, uint8_t, where_u8_i64)
+#endif
@ -18,7 +18,9 @@ METAL_FUNC uint get_strided_index(
}

template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
+template <typename T> METAL_FUNC T recip(T in){ return T(1.0 / in); }
template <typename T> METAL_FUNC T neg(T in){ return -in; }

template <typename T> METAL_FUNC T erf(T in){
float x = (float) in;
// constants
@ -56,8 +58,6 @@ template <typename T> METAL_FUNC T gelu(T x) {
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}

#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
@ -101,17 +101,24 @@ UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
+UNARY_OP(abs)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
+UNARY_OP(recip)

UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)

+#if __METAL_VERSION__ >= 220
+UNARY(id, int64_t, copy_i64, copy_i64_strided)
+#endif

#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
@ -127,6 +134,7 @@ BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
+BFLOAT_UNARY_OP(recip)

UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif
@ -7,15 +7,21 @@
//! running stats.
//!
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Result, Tensor};
+use candle::{DType, Result, Tensor, Var};

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchNormConfig {
pub eps: f64,
pub remove_mean: bool,

/// The meaning of affine here is different from LayerNorm: when false there is no learnable
/// parameter at all, 1 used for gamma and 0 for beta.
pub affine: bool,
+
+/// Controls exponential moving average of running stats. Defaults to 0.1
+///
+/// `running_stat * (1.0 - momentum) + stat * momentum`.
+pub momentum: f64,
}

impl Default for BatchNormConfig {
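The `momentum` field added above follows the update rule quoted in its doc comment: after each training step the running statistic moves a fraction `momentum` of the way toward the statistic of the current batch. A minimal sketch of that update in plain Rust, with made-up numbers purely for illustration:

```rust
// EMA update used for the running statistics:
// running_stat = running_stat * (1.0 - momentum) + stat * momentum
fn ema_update(running_stat: f64, stat: f64, momentum: f64) -> f64 {
    running_stat * (1.0 - momentum) + stat * momentum
}

fn main() {
    let momentum = 0.1; // the default set below
    let running_mean = 0.0; // freshly initialized running mean
    let batch_mean = 0.25; // hypothetical mean of the current batch
    let updated = ema_update(running_mean, batch_mean, momentum);
    assert!((updated - 0.025).abs() < 1e-12);
}
```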
@ -24,6 +30,7 @@ impl Default for BatchNormConfig {
eps: 1e-5,
remove_mean: true,
affine: true,
+momentum: 0.1,
}
}
}
@ -32,23 +39,61 @@ impl From<f64> for BatchNormConfig {
fn from(eps: f64) -> Self {
Self {
eps,
-remove_mean: true,
-affine: true,
+..Default::default()
}
}
}

#[derive(Clone, Debug)]
pub struct BatchNorm {
-running_mean: Tensor,
+running_mean: Var,
-running_var: Tensor,
+running_var: Var,
weight_and_bias: Option<(Tensor, Tensor)>,
remove_mean: bool,
eps: f64,
-num_features: usize,
+momentum: f64,
}

impl BatchNorm {
+fn check_validity(&self, num_features: usize) -> Result<()> {
+if self.eps < 0. {
+candle::bail!("batch-norm eps cannot be negative {}", self.eps)
+}
+if !(0.0..=1.0).contains(&self.momentum) {
+candle::bail!(
+"batch-norm momentum must be between 0 and 1, is {}",
+self.momentum
+)
+}
+if self.running_mean.dims() != [num_features] {
+candle::bail!(
+"batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]",
+self.running_mean.shape(),
+)
+}
+if self.running_var.dims() != [num_features] {
+candle::bail!(
+"batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]",
+self.running_var.shape(),
+)
+}
+if let Some((ref weight, ref bias)) = self.weight_and_bias.as_ref() {
+if weight.dims() != [num_features] {
+candle::bail!(
+"batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
+weight.shape(),
+)
+}
+if bias.dims() != [num_features] {
+candle::bail!(
+"batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
+bias.shape(),
+)
+}
+}
+Ok(())
+}
+
pub fn new(
num_features: usize,
running_mean: Tensor,
@ -57,29 +102,16 @@ impl BatchNorm {
bias: Tensor,
eps: f64,
) -> Result<Self> {
-if eps < 0. {
-candle::bail!("batch-norm eps cannot be negative {eps}")
-}
-if weight.dims() != [num_features] {
-candle::bail!(
-"batch-norm unexpected weight shape {:?} {num_features}",
-weight.shape()
-)
-}
-if bias.dims() != [num_features] {
-candle::bail!(
-"batch-norm unexpected bias shape {:?} {num_features}",
-bias.shape()
-)
-}
-Ok(Self {
-running_mean,
-running_var,
+let out = Self {
+running_mean: Var::from_tensor(&running_mean)?,
+running_var: Var::from_tensor(&running_var)?,
weight_and_bias: Some((weight, bias)),
remove_mean: true,
eps,
-num_features,
+momentum: 0.1,
-})
+};
+out.check_validity(num_features)?;
+Ok(out)
}

pub fn new_no_bias(
@ -88,25 +120,64 @@ impl BatchNorm {
running_var: Tensor,
eps: f64,
) -> Result<Self> {
-if eps < 0. {
-candle::bail!("batch-norm eps cannot be negative {eps}")
-}
-Ok(Self {
-running_mean,
-running_var,
+let out = Self {
+running_mean: Var::from_tensor(&running_mean)?,
+running_var: Var::from_tensor(&running_var)?,
weight_and_bias: None,
remove_mean: true,
eps,
-num_features,
+momentum: 0.1,
-})
+};
+out.check_validity(num_features)?;
+Ok(out)
+}
+
+pub fn new_with_momentum(
+num_features: usize,
+running_mean: Tensor,
+running_var: Tensor,
+weight: Tensor,
+bias: Tensor,
+eps: f64,
+momentum: f64,
+) -> Result<Self> {
+let out = Self {
+running_mean: Var::from_tensor(&running_mean)?,
+running_var: Var::from_tensor(&running_var)?,
+weight_and_bias: Some((weight, bias)),
+remove_mean: true,
+eps,
+momentum,
+};
+out.check_validity(num_features)?;
+Ok(out)
+}
+
+pub fn new_no_bias_with_momentum(
+num_features: usize,
+running_mean: Tensor,
+running_var: Tensor,
+eps: f64,
+momentum: f64,
+) -> Result<Self> {
+let out = Self {
+running_mean: Var::from_tensor(&running_mean)?,
+running_var: Var::from_tensor(&running_var)?,
+weight_and_bias: None,
+remove_mean: true,
+eps,
+momentum,
+};
+out.check_validity(num_features)?;
+Ok(out)
}

pub fn running_mean(&self) -> &Tensor {
-&self.running_mean
+self.running_mean.as_tensor()
}

pub fn running_var(&self) -> &Tensor {
-&self.running_var
+self.running_var.as_tensor()
}

pub fn eps(&self) -> f64 {
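The new `new_with_momentum`/`new_no_bias_with_momentum` constructors take the momentum explicitly instead of the 0.1 default used by `new`/`new_no_bias`. A minimal usage sketch, assuming the `candle` tensor constructors and the `candle_nn::batch_norm` items shown in this diff:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::batch_norm::BatchNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let num_features = 5;
    // Fresh running statistics: mean 0, variance 1.
    let running_mean = Tensor::zeros(num_features, DType::F32, &dev)?;
    let running_var = Tensor::ones(num_features, DType::F32, &dev)?;
    // Identity affine parameters: gamma = 1, beta = 0.
    let weight = Tensor::ones(num_features, DType::F32, &dev)?;
    let bias = Tensor::zeros(num_features, DType::F32, &dev)?;
    // Same as `BatchNorm::new`, but with an explicit momentum of 0.2.
    let bn = BatchNorm::new_with_momentum(
        num_features, running_mean, running_var, weight, bias, 1e-5, 0.2,
    )?;
    assert_eq!(bn.momentum(), 0.2);
    Ok(())
}
```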
@ -117,7 +188,12 @@ impl BatchNorm {
self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
}

-pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+pub fn momentum(&self) -> f64 {
+self.momentum
+}
+
+pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
+let num_features = self.running_mean.as_tensor().dim(0)?;
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
@ -129,40 +205,54 @@ impl BatchNorm {
x.shape()
)
}
-if x.dim(1)? != self.num_features {
+if x.dim(1)? != num_features {
candle::bail!(
"batch-norm input doesn't have the expected number of features ({:?} <> {})",
x.shape(),
-self.num_features
+num_features
)
}
let x = x.to_dtype(internal_dtype)?;
let x = x.transpose(0, 1)?;
let x_dims_post_transpose = x.dims();
+// Flatten all the dimensions exception the channel one as this performs a Spatial Batch
+// Normalization.
let x = x.flatten_from(1)?.contiguous()?;
let x = if self.remove_mean {
+// The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
let mean_x = x.mean_keepdim(1)?;
+let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
++ (mean_x.flatten_all()? * self.momentum)?)?;
+self.running_mean.set(&updated_running_mean)?;
x.broadcast_sub(&mean_x)?
} else {
x
};
+// The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
let norm_x = x.sqr()?.mean_keepdim(1)?;
-let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-let x = x_normed.to_dtype(x_dtype)?;
+let updated_running_var = {
+let batch_size = x.dim(1)? as f64;
+let running_var_weight = 1.0 - self.momentum;
+let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
+((self.running_var.as_tensor() * running_var_weight)?
++ (&norm_x.flatten_all()? * norm_x_weight)?)?
+};
+self.running_var.set(&updated_running_var)?;
+let x = x
+.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+.to_dtype(x_dtype)?;
let x = match &self.weight_and_bias {
None => x,
Some((weight, bias)) => {
-let weight = weight.reshape((self.num_features, 1))?;
+let weight = weight.reshape(((), 1))?;
-let bias = bias.reshape((self.num_features, 1))?;
+let bias = bias.reshape(((), 1))?;
x.broadcast_mul(&weight)?.broadcast_add(&bias)?
}
};
x.reshape(x_dims_post_transpose)?.transpose(0, 1)
}
-}
-
-impl crate::Module for BatchNorm {
-fn forward(&self, x: &Tensor) -> Result<Tensor> {
+
+fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
let target_shape: Vec<usize> = x
.dims()
.iter()
@ -170,9 +260,13 @@ impl crate::Module for BatchNorm {
.map(|(idx, v)| if idx == 1 { *v } else { 1 })
.collect();
let target_shape = target_shape.as_slice();

let x = x
-.broadcast_sub(&self.running_mean.reshape(target_shape)?)?
-.broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?;
+.broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
+.broadcast_div(
+&(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
+)?;

match &self.weight_and_bias {
None => Ok(x),
Some((weight, bias)) => {
@ -184,30 +278,41 @@ impl crate::Module for BatchNorm {
}
}

+impl crate::ModuleT for BatchNorm {
+fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+if train {
+self.forward_train(x)
+} else {
+self.forward_eval(x)
+}
+}
+}
+
pub fn batch_norm<C: Into<BatchNormConfig>>(
num_features: usize,
config: C,
vb: crate::VarBuilder,
) -> Result<BatchNorm> {
+use crate::Init;
let config = config.into();
if config.eps < 0. {
candle::bail!("batch-norm eps cannot be negative {}", config.eps)
}
-let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?;
+let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?;
-let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?;
+let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?;
let weight_and_bias = if config.affine {
-let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?;
+let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?;
-let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?;
+let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?;
Some((weight, bias))
} else {
None
};
Ok(BatchNorm {
-running_mean,
+running_mean: Var::from_tensor(&running_mean)?,
-running_var,
+running_var: Var::from_tensor(&running_var)?,
weight_and_bias,
remove_mean: config.remove_mean,
eps: config.eps,
-num_features,
+momentum: config.momentum,
})
}
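With the `Module` impl replaced by `ModuleT` above, callers now pick training or inference behaviour explicitly. A minimal sketch, assuming the `candle_nn` items appearing in this diff plus a `VarMap`-backed `VarBuilder` as commonly used in candle-nn (the `VarMap` usage is an assumption, not part of this commit):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::batch_norm::{batch_norm, BatchNormConfig};
use candle_nn::{ModuleT, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let bn = batch_norm(5, BatchNormConfig::default(), vb)?;
    let xs = Tensor::randn(0f32, 1f32, (8, 5, 3, 4), &dev)?;
    // `true` runs forward_train (updates the running stats),
    // `false` runs forward_eval (uses them without updating).
    let _train_out = bn.forward_t(&xs, true)?;
    let _eval_out = bn.forward_t(&xs, false)?;
    // Equivalent, via the Tensor helper used by the model code later in this diff.
    let _eval_out2 = xs.apply_t(&bn, false)?;
    Ok(())
}
```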
candle-nn/src/encoding.rs (new file, 150 lines)
@ -0,0 +1,150 @@
+//! Encoding Utilities. (e.g., one-hot/cold encoding)
+
+use candle::{bail, DType, Result, Tensor, WithDType};
+
+/// One-hot/cold encoding.
+///
+/// Given an input tensor of indices, this function returns a tensor of the same shape as the input
+/// tensor with an additional dimension of the given depth size. The values in the returned tensor are
+/// all set to the `off_value` except for the positions represented by the indices, which are set to the `on_value`.
+///
+/// This method returns a tensor with a rank that is one rank larger than the input tensor.
+///
+/// As an example, the following tensor will be encoded to a one-hot matrix:
+///
+/// `[[0i64, 2], [1, -1]]`
+///
+/// with a depth of 4 will be encoded to:
+///
+/// `[[[1, 0, 0, 0], [0, 0, 1, 0]], [[0, 1, 0, 0], [0, 0, 0, 0]]]`
+///
+/// When the input tensor index has a value of -1, the corresponding one-hot vector will be ignored,
+/// resulting in a vector of values set to the `off_value`.
+///
+///
+/// This method supports one-cold encoding by setting `on_value` to `0` and `off_value` to `1`.
+/// By default `on_value` is `1` and `off_value` is `0`.
+///
+/// Other encoding values can be used by setting `on_value` and `off_value` to the desired values.
+///
+/// # Examples
+///
+/// ## One-hot encoding
+///
+/// ```rust
+/// use candle::{Shape, Tensor, Device};
+/// use candle_nn::encoding::one_hot;
+///
+/// let device = candle::Device::Cpu;
+///
+/// let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device).unwrap();
+/// let depth = 4;
+/// let one_hot = one_hot(indices, depth, 1f32, 0f32).unwrap();
+///
+/// let expected_matrix = [
+///     [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
+///     [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
+/// ];
+///
+/// assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth)));
+///
+/// let matrix = one_hot.to_vec3::<f32>().unwrap();
+///
+/// assert_eq!(matrix, expected_matrix);
+///```
+/// ## One-cold Encoding
+///
+/// ```rust
+/// use candle::{Shape, Tensor, Device};
+/// use candle_nn::encoding::one_hot;
+///
+///
+/// let device = candle::Device::Cpu;
+/// let depth = 4;
+/// let indices = Tensor::new(vec![vec![0u8, 2], vec![1, 3]], &device).unwrap();
+/// let one_cold = one_hot(indices, depth, 0u8, 1u8).unwrap();
+///
+/// let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 0]]];
+///
+/// assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth)));
+///
+/// let matrix = one_cold.to_vec3::<u8>().unwrap();
+///
+/// assert_eq!(matrix, expected_matrix);
+/// ```
+///
+///
+/// # Bails
+///
+/// This method bails if:
+/// - One of the index value is less than -1.
+/// - One of the index value is greater than or equal to the depth value.
+/// - The input data type is not `U8`, `U32`, or `I64`.
+///
+/// # API Design
+///
+/// The api design for this method is loosely based on the [TensorFlow One-Hot](https://www.tensorflow.org/api_docs/python/tf/one_hot) method.
+pub fn one_hot<D: WithDType>(
+    indices: Tensor,
+    depth: usize,
+    on_value: D,
+    off_value: D,
+) -> Result<Tensor> {
+    let mut target_shape = indices.dims().to_vec();
+    target_shape.push(depth);
+    let indices = indices.flatten_all()?;
+    let mut out = vec![off_value; depth * indices.elem_count()];
+    match indices.dtype() {
+        DType::U8 => {
+            let indices = indices.to_vec1::<u8>()?;
+            for (i, &index) in indices.iter().enumerate() {
+                set_at_index(index, i * depth, depth, &mut out, on_value)?;
+            }
+        }
+        DType::U32 => {
+            let indices = indices.to_vec1::<u32>()?;
+            for (i, &index) in indices.iter().enumerate() {
+                set_at_index(index, i * depth, depth, &mut out, on_value)?;
+            }
+        }
+        DType::I64 => {
+            let indices = indices.to_vec1::<i64>()?;
+            for (i, &index) in indices.iter().enumerate() {
+                set_at_index(index, i * depth, depth, &mut out, on_value)?;
+            }
+        }
+        dtype => {
+            bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64")
+        }
+    };
+    Tensor::from_vec(out, target_shape, indices.device())
+}
+
+fn set_at_index<D: WithDType, I: Into<i64>>(
+    value: I,
+    offset: usize,
+    depth: usize,
+    v: &mut Vec<D>,
+    on_value: D,
+) -> Result<()> {
+    let value = value.into();
+    // Skip for an entire row of off_values
+    if value == -1 {
+        return Ok(());
+    }
+    if value < -1 {
+        bail!(
+            "one_hot: invalid negative index value {value}, expected a positive index value or -1"
+        );
+    }
+    let value = value as usize;
+    if value >= depth {
+        bail!("one_hot: index value {value} exceeds depth {depth}")
+    }
+    let idx = offset + value;
+    if idx >= v.len() {
+        bail!("one_hot: index out of bounds {idx}, len {}", v.len());
+    }
+    v[idx] = on_value;
+    Ok(())
+}
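The flat layout used by `one_hot` above stores the encoding of `indices[i]` at offset `i * depth`, and `set_at_index` then writes `on_value` at `offset + index`, skipping `-1` rows entirely. A small self-contained illustration of that index arithmetic, in plain Rust without candle types:

```rust
// Mirror of the flat indexing used by one_hot/set_at_index:
// the row for indices[i] lives at out[i * depth .. (i + 1) * depth].
fn main() {
    let indices: [i64; 4] = [0, 2, 1, -1]; // flattened [[0, 2], [1, -1]]
    let depth = 4;
    let mut out = vec![0u8; indices.len() * depth];
    for (i, &index) in indices.iter().enumerate() {
        if index >= 0 {
            out[i * depth + index as usize] = 1;
        } // index == -1 leaves the whole row at off_value
    }
    assert_eq!(out, vec![1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0]);
}
```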
@ -2,6 +2,7 @@ pub mod activation;
pub mod batch_norm;
pub mod conv;
pub mod embedding;
+pub mod encoding;
pub mod func;
pub mod group_norm;
pub mod init;
@ -16,6 +16,8 @@ input = torch.randn(2, 5, 3, 4)
output = m(input)
print(input.flatten())
print(output.flatten())
+print(m.running_mean)
+print(m.running_var)
*/
#[test]
fn batch_norm() -> Result<()> {
@ -37,7 +39,7 @@ fn batch_norm() -> Result<()> {
1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
];
let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
-let output = bn.forward_learning(&input)?;
+let output = bn.forward_train(&input)?;
assert_eq!(output.dims(), &[2, 5, 3, 4]);
let output = output.flatten_all()?;
assert_eq!(
@ -65,11 +67,20 @@ fn batch_norm() -> Result<()> {
Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
1e-8,
)?;
-let output2 = bn2.forward_learning(&input)?;
+let output2 = bn2.forward_train(&input)?;
assert_eq!(output2.dims(), &[2, 5, 3, 4]);
let output2 = output2.flatten_all()?;
let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
let sum_diff2 = diff2.sum_keepdim(0)?;
assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]);
+
+assert_eq!(
+    test_utils::to_vec1_round(bn.running_mean(), 4)?,
+    &[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020]
+);
+assert_eq!(
+    test_utils::to_vec1_round(bn.running_var(), 4)?,
+    &[0.9972, 0.9842, 0.9956, 0.9866, 0.9898]
+);
Ok(())
}
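The new assertions above check the running statistics after a single `forward_train` call: starting from a running mean of 0 and running variance of 1, each is moved `momentum = 0.1` of the way toward the batch statistic (with the variance term additionally scaled by `n / (n - 1)` in `forward_train`). A small sketch of that arithmetic with a made-up per-channel batch mean, purely to show where numbers of this magnitude come from:

```rust
// Illustrative only: how a running mean of about -0.0133 can arise after one
// training step with momentum 0.1 starting from a running mean of 0.
fn main() {
    let momentum = 0.1;
    let running_mean = 0.0f64;
    let batch_mean = -0.133; // hypothetical mean of that channel's batch
    let updated = running_mean * (1.0 - momentum) + batch_mean * momentum;
    assert!((updated - (-0.0133)).abs() < 1e-9);
}
```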
candle-nn/tests/one_hot.rs (new file, 120 lines)
@ -0,0 +1,120 @@
+use candle::{Result, Shape, Tensor};
+use candle_nn::encoding::one_hot;
+
+#[test]
+fn test_i64_one_hot() -> Result<()> {
+    let device = candle::Device::Cpu;
+
+    let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?;
+    let depth = 4;
+
+    let on_value = 1.0;
+    let off_value = 0.0;
+
+    let one_hot = one_hot::<f32>(indices, depth, on_value, off_value)?;
+
+    let expected_matrix = [
+        [[1., 0., 0., 0.], [0., 0., 1., 0.]],
+        [[0., 1., 0., 0.], [0., 0., 0., 0.]],
+    ];
+
+    assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth)));
+
+    let matrix = one_hot.to_vec3::<f32>()?;
+
+    assert_eq!(matrix, expected_matrix);
+
+    Ok(())
+}
+
+#[test]
+fn test_rank_3_one_hot() -> Result<()> {
+    let device = candle::Device::Cpu;
+
+    let indices = Tensor::new(
+        vec![
+            vec![vec![0i64, 1], vec![2, 3]],
+            vec![vec![3, 1], vec![1, -1]],
+        ],
+        &device,
+    )?;
+    let depth = 4;
+
+    let on_value = 1.0;
+    let off_value = 0.0;
+
+    let one_hot = one_hot::<f32>(indices, depth, on_value, off_value)?;
+
+    let expected_matrix = Tensor::new(
+        vec![
+            vec![
+                vec![vec![1f32, 0., 0., 0.], vec![0., 1., 0., 0.]],
+                vec![vec![0., 0., 1., 0.], vec![0., 0., 0., 1.]],
+            ],
+            vec![
+                vec![vec![0., 0., 0., 1.], vec![0., 1., 0., 0.]],
+                vec![vec![0., 1., 0., 0.], vec![0., 0., 0., 0.]],
+            ],
+        ],
+        &device,
+    )?;
+
+    assert_eq!(one_hot.shape(), expected_matrix.shape());
+    assert_eq!(one_hot.dims(), expected_matrix.dims());
+
+    let matrix = one_hot.get(1)?.to_vec3::<f32>()?;
+    let expected_matrix = expected_matrix.get(1)?.to_vec3::<f32>()?;
+
+    assert_eq!(matrix, expected_matrix);
+
+    Ok(())
+}
+
+#[test]
+fn test_u8_one_cold() -> Result<()> {
+    let device = candle::Device::Cpu;
+    let depth = 4;
+    let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?;
+
+    let on_value = 0u8;
+    let off_value = 1;
+
+    // Note that the method does not require the turbofish operator, as the type is inferred from the on_value.
+    let one_cold = one_hot(indices, depth, on_value, off_value)?;
+
+    let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 1]]];
+
+    assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth)));
+
+    let matrix = one_cold.to_vec3::<u8>()?;
+
+    assert_eq!(matrix, expected_matrix);
+
+    Ok(())
+}
+
+#[test]
+fn test_iter() -> Result<()> {
+    let device = candle::Device::Cpu;
+    let depth = 4;
+    let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?;
+    let matrix = indices.to_vec2::<i64>()?;
+    let (dim1, dim2) = indices.dims2()?;
+
+    let iter = (0..dim1).flat_map(|i| (0..dim2).map(move |j| (i, j)));
+
+    let mut v = vec![0; depth * dim1 * dim2];
+
+    for (i, j) in iter {
+        let idx = i * depth * dim2 + j * depth;
+        v[idx] = matrix[i][j];
+    }
+
+    for (i, row) in matrix.iter().enumerate() {
+        for (j, &value) in row.iter().enumerate() {
+            let idx = i * depth * dim2 + j * depth;
+            assert_eq!(v[idx], value);
+        }
+    }
+    Ok(())
+}
@ -40,8 +40,8 @@ fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module>
let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
Ok(candle_nn::func(move |xs| {
-let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
+let ys = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
-(xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2)
+(xs + ys)?.apply(&conv2)?.gelu_erf()?.apply_t(&bn2, false)
}))
}
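The model code in the remaining hunks switches from `apply(&bn)` to `apply_t(&bn, false)`: since `BatchNorm` now implements `ModuleT` rather than `Module` in this diff, callers must say whether the layer runs in training mode, and these inference-only models pass `false` so the running statistics are used but never updated. A minimal sketch of the pattern, assuming the `candle` `apply`/`apply_t` helpers used throughout this diff (`conv_bn_eval` is a hypothetical helper, not part of the commit):

```rust
use candle::{Result, Tensor};
use candle_nn::batch_norm::BatchNorm;
use candle_nn::conv::Conv2d;

// Sketch of the conv -> batch-norm chain used by the model changes below,
// with the batch-norm applied in eval mode (train = false).
fn conv_bn_eval(xs: &Tensor, conv: &Conv2d, bn: &BatchNorm) -> Result<Tensor> {
    xs.apply(conv)?.apply_t(bn, false)
}
```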
@ -64,7 +64,7 @@ fn convmixer(
.collect::<Result<Vec<_>>>()?;
let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
Ok(candle_nn::func(move |xs| {
-let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
+let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply_t(&bn1, false)?;
for block in blocks.iter() {
xs = xs.apply(block)?
}
@ -169,8 +169,7 @@ impl ConvNormActivation {

impl Module for ConvNormActivation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-let xs = self.conv2d.forward(xs)?;
-let xs = self.bn2d.forward(&xs)?;
+let xs = self.conv2d.forward(xs)?.apply_t(&self.bn2d, false)?;
if self.activation {
swish(&xs)
} else {
@ -25,7 +25,7 @@ fn downsample(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resul
if stride != 1 || c_in != c_out {
let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0))?;
let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
-Ok(Func::new(move |xs| xs.apply(&conv)?.apply(&bn)))
+Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
} else {
Ok(Func::new(|xs| Ok(xs.clone())))
}
@ -40,10 +40,10 @@ fn basic_block(c_in: usize, c_out: usize, stride: usize, vb: VarBuilder) -> Resu
Ok(Func::new(move |xs| {
let ys = xs
.apply(&conv1)?
-.apply(&bn1)?
+.apply_t(&bn1, false)?
.relu()?
.apply(&conv2)?
-.apply(&bn2)?;
+.apply_t(&bn2, false)?;
(xs.apply(&downsample)? + ys)?.relu()
}))
}
@ -94,7 +94,7 @@ fn resnet(
Ok(Func::new(move |xs| {
let xs = xs
.apply(&conv1)?
-.apply(&bn1)?
+.apply_t(&bn1, false)?
.relu()?
.pad_with_same(D::Minus1, 1, 1)?
.pad_with_same(D::Minus2, 1, 1)?
@ -149,13 +149,13 @@ fn bottleneck_block(
Ok(Func::new(move |xs| {
let ys = xs
.apply(&conv1)?
-.apply(&bn1)?
+.apply_t(&bn1, false)?
.relu()?
.apply(&conv2)?
-.apply(&bn2)?
+.apply_t(&bn2, false)?
.relu()?
.apply(&conv3)?
-.apply(&bn3)?;
+.apply_t(&bn3, false)?;
(xs.apply(&downsample)? + ys)?.relu()
}))
}
@ -206,7 +206,7 @@ fn bottleneck_resnet(
Ok(Func::new(move |xs| {
let xs = xs
.apply(&conv1)?
-.apply(&bn1)?
+.apply_t(&bn1, false)?
.relu()?
.pad_with_same(D::Minus1, 1, 1)?
.pad_with_same(D::Minus2, 1, 1)?
@ -28,7 +28,7 @@ impl Conv2dBN {
impl Module for Conv2dBN {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
-xs.apply(&self.c)?.apply(&self.bn)
+xs.apply(&self.c)?.apply_t(&self.bn, false)
}
}
@ -185,7 +185,7 @@ impl PaellaVQ {
xs = xs.apply(&down_block.1)?
}
xs.apply(&self.down_blocks_conv)?
-.apply(&self.down_blocks_bn)
+.apply_t(&self.down_blocks_bn, false)
}

pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
@ -107,8 +107,7 @@ impl ConvBlock {

impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-let xs = self.conv.forward(xs)?;
-let xs = self.bn.forward(&xs)?;
+let xs = self.conv.forward(xs)?.apply_t(&self.bn, false)?;
candle_nn::ops::silu(&xs)
}
}