Laurent
2024-01-07 20:27:33 +01:00
57 changed files with 1309 additions and 1409 deletions

View File

@@ -31,6 +31,14 @@ license = "MIT OR Apache-2.0"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
+candle = { path = "./candle-core", package = "candle-core" }
+candle-datasets = { path = "./candle-datasets" }
+candle-flash-attn = { path = "./candle-flash-attn" }
+candle-kernels = { path = "./candle-kernels" }
+candle-metal-kernels = { path = "./candle-metal-kernels" }
+candle-nn = { path = "./candle-nn" }
+candle-onnx = { path = "./candle-onnx" }
+candle-transformers = { path = "./candle-transformers" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.9.14", features = ["f16"] }

View File

@@ -158,6 +158,7 @@ And then head over to
 - [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
 - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
 - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
+- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
 If you have an addition to this list, please submit a pull request.

View File

@@ -11,11 +11,11 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
+candle = { workspace = true }
+candle-datasets = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }

View File

@@ -12,8 +12,8 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true }
+candle-kernels = { workspace = true, optional = true }
+candle-metal-kernels = { workspace = true, optional = true }
 metal = { workspace = true, optional = true}
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }

View File

@@ -12,6 +12,14 @@ use core::arch::arm::*;
 #[cfg(target_arch = "aarch64")]
 use core::arch::aarch64::*;
+#[inline(always)]
+unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
+// TODO: dotprod
+let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
+let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
+}
 #[inline(always)]
 pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
 let qk = QK8_0;
@@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
 let v1_0l = vld1q_s8(y0.qs.as_ptr());
 let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
-// TODO: Support dotprod when it's available outside of nightly.
-let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
-let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
-let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
-let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+let pl0 = vdotq_s32(v0_0ls, v1_0l);
+let ph0 = vdotq_s32(v0_0hs, v1_0h);
 sumv0 = vmlaq_n_f32(
 sumv0,
 vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
@@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
 let y0_0 = vld1q_s8(y0.qs.as_ptr());
 let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
-// TODO dotprod once this is the intrinsics are.
-let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
-let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
-let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
-let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
-let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+let p0 = vdotq_s32(x0_0, y0_0);
+let p1 = vdotq_s32(x0_1, y0_1);
 sumv0 = vmlaq_n_f32(
 sumv0,
@@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
 for i in (0..QK_K).step_by(16) {
 let xs = vld1q_s8(xs.add(i));
 let ys = vld1q_s8(ys.add(i));
-let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
-let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
-let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+let xy = vdotq_s32(xs, ys);
 sum_i = vaddq_s32(sum_i, xy)
 }
 sumf += vaddvq_s32(sum_i) as f32 * scale
@@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
 let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
 let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
-// TODO: dotprod
-let p0 = vaddq_s16(
-vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-);
-let p1 = vaddq_s16(
-vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-);
+let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
 scale = scale.add(2);
-let p2 = vaddq_s16(
-vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-);
-let p3 = vaddq_s16(
-vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-);
+let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
 scale = scale.add(2);
 let q8bytes = vld1q_s8_x4(q8);
@@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
 let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
 let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
-// TODO: dotprod case.
-let p0 = vaddq_s16(
-vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-);
-let p1 = vaddq_s16(
-vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-);
+let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
 scale = scale.add(2);
-let p2 = vaddq_s16(
-vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-);
-let p3 = vaddq_s16(
-vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-);
+let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
 scale = scale.add(2);
 }
 sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
 let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
 let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
-// TODO: dotprod
-let p0 = vaddq_s16(
-vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
-vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
-);
-let p1 = vaddq_s16(
-vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
-vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
-);
-sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
+let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
+let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
+sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
 scales = scales.add(1);
-let p2 = vaddq_s16(
-vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
-vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
-);
-let p3 = vaddq_s16(
-vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
-vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
-);
-sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
+let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
+let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
+sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
 scales = scales.add(1);
 }
 sumf += d * sumi as f32 - dmin * sumi_mins as f32;
@@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
 for j in 0..QK_K / 64 {
 let q4bits = vld1q_u8_x2(q4);
 q4 = q4.add(32);
-// TODO: dotprod
 let q8bytes = vld1q_s8_x2(q8);
 q8 = q8.add(32);
 let q4bytes = int8x16x2_t(
 vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
 vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
 );
-let p0 = vaddq_s16(
-vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-);
-let p1 = vaddq_s16(
-vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-);
-sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
+let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
+let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
+sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
 let q8bytes = vld1q_s8_x2(q8);
 q8 = q8.add(32);
@@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
 vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
 vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
 );
-let p2 = vaddq_s16(
-vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-);
-let p3 = vaddq_s16(
-vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-);
-sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
+let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
+let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
+sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
 }
 sumf += d * (sumi1 + sumi2) as f32;
 }
@@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
 vreinterpretq_s8_u8(q3h_3),
 );
-// TODO: dotprod
-let p0 = vaddq_s16(
-vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
-vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
-);
-let p1 = vaddq_s16(
-vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
-vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
-);
-let p2 = vaddq_s16(
-vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
-vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
-);
-let p3 = vaddq_s16(
-vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
-vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
-);
-isum += vaddvq_s16(p0) as i32 * *scale as i32
-+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
+let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
+let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
+let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
+isum += vaddvq_s32(p0) * *scale as i32
++ vaddvq_s32(p1) * *scale.add(1) as i32
++ vaddvq_s32(p2) * *scale.add(2) as i32
++ vaddvq_s32(p3) * *scale.add(3) as i32;
 scale = scale.add(4);
 let q3h_0 = vbicq_u8(m2, qhbits.0);
@@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
 vreinterpretq_s8_u8(q3h_3),
 );
-// TODO: dotprod
-let p0 = vaddq_s16(
-vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
-vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
-);
-let p1 = vaddq_s16(
-vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
-vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
-);
-let p2 = vaddq_s16(
-vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
-vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
-);
-let p3 = vaddq_s16(
-vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
-vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
-);
-isum += vaddvq_s16(p0) as i32 * *scale as i32
-+ vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-+ vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-+ vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
+let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
+let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
+let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
+isum += vaddvq_s32(p0) * *scale as i32
++ vaddvq_s32(p1) * *scale.add(1) as i32
++ vaddvq_s32(p2) * *scale.add(2) as i32
++ vaddvq_s32(p3) * *scale.add(3) as i32;
 scale = scale.add(4);
 if j == 0 {
@@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
 let mut is = 0usize;
 // TODO: dotprod
 for _j in 0..QK_K / 128 {
 let q2bits = vld1q_u8_x2(q2);
 q2 = q2.add(32);
@@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
 q2bytes: int8x16x2_t,
 q8bytes: int8x16x2_t,
 ) -> i32 {
-let p1 = vaddq_s16(
-vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
-vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
-);
-let p2 = vaddq_s16(
-vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
-vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
-);
-vaddvq_s16(p1) as i32 * aux[is + index] as i32
-+ vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
+let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
+let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
+vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
 }
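Note: the `vdotq_s32` helper added at the top of this file emulates a 16-lane i8 dot product with widening multiplies and pairwise additions, and all the kernels above now reduce through it. On CPUs with the Armv8.2 `dotprod` extension the same reduction is a single SDOT instruction. The sketch below is illustrative only and is not part of this commit: the `vdotq_s32_native` name and the feature gate are assumptions, and at the time of this change the corresponding Rust intrinsic is nightly-only.

```rust
// aarch64-only sketch; assumes a toolchain where the native dotprod intrinsic is usable.
use core::arch::aarch64::*;

#[inline(always)]
#[cfg(all(target_arch = "aarch64", target_feature = "dotprod"))]
unsafe fn vdotq_s32_native(a: int8x16_t, b: int8x16_t) -> int32x4_t {
    // The hardware intrinsic accumulates into its first argument, so start from zero.
    // The per-lane grouping differs from the polyfill in this file, but the horizontal
    // sum is identical, and callers only ever reduce the result with vaddvq_s32.
    core::arch::aarch64::vdotq_s32(vdupq_n_s32(0), a, b)
}
```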

View File

@@ -1,4 +1,5 @@
 use candle_core::{
+bail,
 quantized::{self, GgmlDType},
 test_utils::to_vec2_round,
 Device, Module, Result, Tensor,
@@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
 }
 }
-/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
+/// Creates a vector similar to the ones used in GGML unit tests:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
 fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
 (0..GGML_TEST_SIZE)
 .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
@@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
 sum / a.len() as f32
 }
-/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
+/// Similar to the GGML quantization unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
 fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
 let src = create_ggml_like_vector(0.0);
 let mut dst = vec![0.0; GGML_TEST_SIZE];
 let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
 let error = calculate_rmse(src.as_slice(), dst.as_slice());
 if error > max_error {
-candle_core::bail!(
+bail!(
 "Quantization error {} exceeds max error {}",
 error,
 max_error
@@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
 GgmlDType::Q5K => 0.000740,
 GgmlDType::Q6K => 0.000952,
 GgmlDType::Q4_0 => 0.001143,
-GgmlDType::Q4_1 => 0.007784,
+GgmlDType::Q4_1 => 0.008,
 GgmlDType::Q5_0 => 0.001353,
-GgmlDType::Q5_1 => 0.001363,
+GgmlDType::Q5_1 => 0.00149,
 GgmlDType::Q8_0 => 0.000092,
 // Not from the ggml repo.
 GgmlDType::Q8K => 0.00065,
-_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
+_ => bail!("No GGML results for quantization type {dtype:?}",),
 };
 Ok(err)
 }
-/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
+/// Similar to the GGML matmul unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
 fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
 let a = create_ggml_like_vector(0.0);
 let b = create_ggml_like_vector(1.0);
+ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
+// Another example that is more likely to trigger the overflow reported in #1526
+let a = (0..GGML_TEST_SIZE)
+.map(|i| i as f32 / GGML_TEST_SIZE as f32)
+.collect::<Vec<_>>();
+let b = (0..GGML_TEST_SIZE)
+.map(|i| i as f32 / GGML_TEST_SIZE as f32)
+.collect::<Vec<_>>();
+ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
+Ok(())
+}
+fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
 let length = a.len();
 let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
 let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
-T::from_float(&a, &mut a_quant)?;
-T::VecDotType::from_float(&b, &mut b_quant)?;
+T::from_float(a, &mut a_quant)?;
+T::VecDotType::from_float(b, &mut b_quant)?;
 let result = T::vec_dot(length, &a_quant, &b_quant)?;
 let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
-let reference_result = vec_dot_reference(&a, &b);
+let reference_result = vec_dot_reference(a, b);
 if (result - result_unopt).abs() / length as f32 > 1e-6 {
-candle_core::bail!(
+bail!(
 "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
 )
 }
 let error = (result - reference_result).abs() / length as f32;
-let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
+let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
 if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
-candle_core::bail!(
-"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
-);
+bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
 }
 // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
 // => we use a slightly higher error threshold
 const ERROR_LENIENCY: f32 = 0.00001;
 if error - ERROR_LENIENCY > ggml_error {
-candle_core::bail!(
+bail!(
 "Dot product error {} exceeds ggml reference error {}",
 error,
 ggml_error
@@ -563,6 +578,16 @@ fn get_small_tensors(
 Ok((lhs, rhs, mm))
 }
+#[test]
+fn quantized_mm() -> Result<()> {
+ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
+ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
+ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
+ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
+ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
+Ok(())
+}
 /// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
 fn get_random_tensors(
 m: usize,
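Note: the new `err_m` multiplier and the second pair of inputs in `ggml_matmul_error_test` target the overflow reported in #1526: the previous NEON reduction kept intermediate sums in `i16`, and a horizontal sum of 8-bit products can exceed `i16::MAX`. A standalone sketch of the arithmetic (illustration only, not part of the test suite):

```rust
// Why accumulating 8-bit dot products in i16 can wrap: 16 products of
// magnitude 127 * 127 = 16129 sum to 258064, far above i16::MAX (32767).
fn main() {
    let a = [127i8; 16];
    let b = [127i8; 16];
    let products: Vec<i16> = a.iter().zip(&b).map(|(&x, &y)| x as i16 * y as i16).collect();
    // Wrapping i16 accumulation (what an i16 horizontal add effectively does)...
    let wrapped = products.iter().fold(0i16, |acc, &p| acc.wrapping_add(p));
    // ...versus the exact result when widening to i32 first, as the new vdotq_s32 helper does.
    let exact: i32 = products.iter().map(|&p| p as i32).sum();
    println!("wrapped i16 sum = {wrapped}, exact i32 sum = {exact}");
}
```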

View File

@@ -11,8 +11,8 @@ readme = "README.md"
 [dependencies]
 byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
 hf-hub = { workspace = true}
 intel-mkl-src = { workspace = true, optional = true }
 memmap2 = { workspace = true }

View File

@@ -11,12 +11,12 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
-candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true }
+candle = { workspace = true }
+candle-datasets = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
+candle-onnx = { workspace = true, optional = true }
 csv = "1.3.0"
 cudarc = { workspace = true, optional = true }
@@ -49,11 +49,12 @@ tokio = "1.29.1"
 [build-dependencies]
 anyhow = { workspace = true }
+bindgen_cuda = { version = "0.1.1", optional = true }
 [features]
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
-cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
 cudnn = ["candle/cudnn"]
 flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

View File

@@ -4,251 +4,34 @@ use std::io::Write;
 use std::path::PathBuf;
 struct KernelDirectories {
-kernel_dir: &'static str,
+kernel_glob: &'static str,
 rust_target: &'static str,
 include_dirs: &'static [&'static str],
 }
-const DIRS: [KernelDirectories; 1] = [KernelDirectories {
-kernel_dir: "examples/custom-ops/kernels/",
+const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
+kernel_glob: "examples/custom-ops/kernels/*.cu",
 rust_target: "examples/custom-ops/cuda_kernels.rs",
 include_dirs: &[],
 }];
impl KernelDirectories {
fn maybe_build_ptx(
&self,
cu_file: &std::path::Path,
ptx_file: &std::path::Path,
compute_cap: usize,
) -> Result<()> {
let should_compile = if ptx_file.exists() {
let ptx_modified = ptx_file.metadata()?.modified()?;
let cu_modified = cu_file.metadata()?.modified()?;
cu_modified.duration_since(ptx_modified).is_ok()
} else {
true
};
if should_compile {
#[cfg(feature = "cuda")]
{
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let mut command = std::process::Command::new("nvcc");
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
let include_dirs: Vec<String> =
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
command
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", out_dir.to_str().unwrap()])
.arg(format!("-I/{}", self.kernel_dir))
.args(include_dirs)
.arg(cu_file);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
#[cfg(not(feature = "cuda"))]
std::fs::OpenOptions::new()
.create(true)
.write(true)
.open(ptx_file)?;
}
Ok(())
}
fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
println!("cargo:rerun-if-changed={}", self.kernel_dir);
let kernel_dir = PathBuf::from(self.kernel_dir);
let out_dir = out_dir.join(self.kernel_dir);
if !out_dir.exists() {
std::fs::create_dir_all(&out_dir)?;
}
let mut cu_files = vec![];
let mut cuh_files = vec![];
for file in std::fs::read_dir(kernel_dir)?.flatten() {
let file = file.path();
match file.extension().and_then(|v| v.to_str()) {
Some("cu") => cu_files.push(file),
Some("cuh") => cuh_files.push(file),
_ => {}
}
}
let mut ptx_paths = vec![];
for cu_file in cu_files.iter() {
let file_stem = cu_file
.file_stem()
.with_context(|| format!("no stem {cu_file:?}"))?;
let file_stem = file_stem.to_string_lossy().into_owned();
let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
ptx_paths.push(ptx_file);
}
let regenerate_rs_file = true;
if regenerate_rs_file {
let mut file = std::fs::File::create(self.rust_target)?;
for ptx_path in ptx_paths {
let name = ptx_path
.file_stem()
.context("empty stem")?
.to_string_lossy();
file.write_all(b"#[rustfmt::skip]\n")?;
let const_definition = format!(
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
name.to_uppercase().replace('.', "_"),
self.kernel_dir,
);
file.write_all(const_definition.as_bytes())?;
file.write_all(b"\n")?;
}
}
Ok(())
}
}
 fn main() -> Result<()> {
 println!("cargo:rerun-if-changed=build.rs");
-let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
-let out_dir = PathBuf::from(out_dir);
 #[cfg(feature = "cuda")]
-set_cuda_include_dir()?;
-#[cfg(feature = "cuda")]
-let compute_cap = compute_cap()?;
+{
+for kdir in KERNEL_DIRS.iter() {
+let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
+println!("cargo:info={builder:?}");
+let bindings = builder.build_ptx().unwrap();
+bindings.write(kdir.rust_target).unwrap()
+}
+}
 #[cfg(not(feature = "cuda"))]
-let compute_cap = 0;
-for d in DIRS {
-d.process(&out_dir, compute_cap)?
+{
+for kdir in KERNEL_DIRS.iter() {
+let _file = std::fs::File::create(kdir.rust_target)?;
+}
 }
 Ok(())
 }
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute cap from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse code")?
} else {
// Grab compute cap from nvidia-smi
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
};
// Grab available GPU codes from nvcc and select the highest one
let max_nvcc_code = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
if !codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
);
}
*codes.last().unwrap()
};
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
);
compute_cap = max_nvcc_code;
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
Ok(compute_cap)
}
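Note: the rewritten build script at the top of this file replaces the hand-rolled nvcc invocation and compute-capability probing deleted above with the `bindgen_cuda` crate. A condensed sketch of the PTX path, using only the builder calls that appear in this commit (the glob and output path below are placeholders):

```rust
// build.rs sketch, assuming bindgen_cuda = "0.1" as added in this commit.
fn main() {
    // Compile every matching .cu file to PTX, then write a Rust source file that
    // exposes one `pub const NAME: &str = include_str!(...)` per PTX file, as seen
    // in the regenerated cuda_kernels.rs later in this diff.
    let builder = bindgen_cuda::Builder::default().kernel_paths_glob("kernels/*.cu");
    println!("cargo:info={builder:?}");
    let bindings = builder.build_ptx().unwrap();
    bindings.write("src/cuda_kernels.rs").unwrap();
}
```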

View File

@@ -1,2 +1 @@
-#[rustfmt::skip]
-pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
+pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));

View File

@@ -6,7 +6,8 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
-#[allow(unused)]
+#[rustfmt::skip]
+#[cfg(feature = "cuda")]
 mod cuda_kernels;
 use clap::Parser;

View File

@@ -165,14 +165,14 @@ fn main() -> Result<()> {
 let mut index_pos = 0;
 let mut token_generated = 0;
 for index in 0..args.sample_len {
-let context_size = if cache.use_kv_cache && index > 0 {
-1
+let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
+(1, index_pos)
 } else {
-tokens.len()
+(tokens.len(), 0)
 };
 let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
 let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-let logits = llama.forward(&input, index_pos)?;
+let logits = llama.forward(&input, context_index)?;
 let logits = logits.squeeze(0)?;
 let logits = if args.repeat_penalty == 1. {
 logits
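Note: the fix above changes what is passed as the position offset: on the first iteration the whole prompt is fed at offset 0, and on later iterations (with the KV cache enabled) a single token is fed at its absolute position. A small illustration of the bookkeeping, assuming a 4-token prompt and that `index_pos` advances by the number of tokens fed each step, as in the example (this is not additional code from the commit):

```rust
fn main() {
    let mut tokens: Vec<u32> = vec![10, 11, 12, 13]; // the prompt
    let mut index_pos = 0;
    let use_kv_cache = true;
    for index in 0..4 {
        let (context_size, context_index) = if use_kv_cache && index > 0 {
            (1, index_pos)
        } else {
            (tokens.len(), 0)
        };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        // The model would be called here as llama.forward(&input, context_index).
        println!("step {index}: feed {} token(s) at position {context_index}", ctxt.len());
        index_pos += ctxt.len();
        tokens.push(100 + index as u32); // stand-in for the sampled next token
    }
}
```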

View File

@@ -11,14 +11,14 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"
 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
 half = { version = "2.3.1", features = ["num-traits"] }
 [build-dependencies]
+bindgen_cuda = "0.1.1"
 anyhow = { version = "1", features = ["backtrace"] }
-num_cpus = "1.15.0"
-rayon = "1.7.0"
 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
-candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] }
+candle-nn = { path = "../candle-nn", features = ["cuda"] }

View File

@@ -2,44 +2,32 @@
 // The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
 // variable in order to cache the compiled artifacts and avoid recompiling too often.
 use anyhow::{Context, Result};
-use rayon::prelude::*;
 use std::path::PathBuf;
-use std::str::FromStr;
 const KERNEL_FILES: [&str; 17] = [
-"flash_api.cu",
-"flash_fwd_hdim128_fp16_sm80.cu",
-"flash_fwd_hdim160_fp16_sm80.cu",
-"flash_fwd_hdim192_fp16_sm80.cu",
-"flash_fwd_hdim224_fp16_sm80.cu",
-"flash_fwd_hdim256_fp16_sm80.cu",
-"flash_fwd_hdim32_fp16_sm80.cu",
-"flash_fwd_hdim64_fp16_sm80.cu",
-"flash_fwd_hdim96_fp16_sm80.cu",
-"flash_fwd_hdim128_bf16_sm80.cu",
-"flash_fwd_hdim160_bf16_sm80.cu",
-"flash_fwd_hdim192_bf16_sm80.cu",
-"flash_fwd_hdim224_bf16_sm80.cu",
-"flash_fwd_hdim256_bf16_sm80.cu",
-"flash_fwd_hdim32_bf16_sm80.cu",
-"flash_fwd_hdim64_bf16_sm80.cu",
-"flash_fwd_hdim96_bf16_sm80.cu",
+"kernels/flash_api.cu",
+"kernels/flash_fwd_hdim128_fp16_sm80.cu",
+"kernels/flash_fwd_hdim160_fp16_sm80.cu",
+"kernels/flash_fwd_hdim192_fp16_sm80.cu",
+"kernels/flash_fwd_hdim224_fp16_sm80.cu",
+"kernels/flash_fwd_hdim256_fp16_sm80.cu",
+"kernels/flash_fwd_hdim32_fp16_sm80.cu",
+"kernels/flash_fwd_hdim64_fp16_sm80.cu",
+"kernels/flash_fwd_hdim96_fp16_sm80.cu",
+"kernels/flash_fwd_hdim128_bf16_sm80.cu",
+"kernels/flash_fwd_hdim160_bf16_sm80.cu",
+"kernels/flash_fwd_hdim192_bf16_sm80.cu",
+"kernels/flash_fwd_hdim224_bf16_sm80.cu",
+"kernels/flash_fwd_hdim256_bf16_sm80.cu",
+"kernels/flash_fwd_hdim32_bf16_sm80.cu",
+"kernels/flash_fwd_hdim64_bf16_sm80.cu",
+"kernels/flash_fwd_hdim96_bf16_sm80.cu",
 ];
 fn main() -> Result<()> {
-let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
-|_| num_cpus::get_physical(),
-|s| usize::from_str(&s).unwrap(),
-);
-rayon::ThreadPoolBuilder::new()
-.num_threads(num_cpus)
-.build_global()
-.unwrap();
 println!("cargo:rerun-if-changed=build.rs");
 for kernel_file in KERNEL_FILES.iter() {
-println!("cargo:rerun-if-changed=kernels/{kernel_file}");
+println!("cargo:rerun-if-changed={kernel_file}");
 }
 println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
 println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
@@ -66,223 +54,30 @@ fn main() -> Result<()> {
 ))
 }
 };
-set_cuda_include_dir()?;
-let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
-println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
-let compute_cap = compute_cap()?;
+let kernels = KERNEL_FILES.iter().collect();
+let builder = bindgen_cuda::Builder::default()
+.kernel_paths(kernels)
+.out_dir(build_dir.clone())
+.arg("-std=c++17")
+.arg("-O3")
+.arg("-U__CUDA_NO_HALF_OPERATORS__")
+.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
+.arg("-U__CUDA_NO_HALF2_OPERATORS__")
+.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
+.arg("-Icutlass/include")
+.arg("--expt-relaxed-constexpr")
+.arg("--expt-extended-lambda")
+.arg("--use_fast_math")
+.arg("--verbose");
 let out_file = build_dir.join("libflashattention.a");
+builder.build_lib(out_file);
let kernel_dir = PathBuf::from("kernels");
let cu_files: Vec<_> = KERNEL_FILES
.iter()
.map(|f| {
let mut obj_file = out_dir.join(f);
obj_file.set_extension("o");
(kernel_dir.join(f), obj_file)
})
.collect();
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
let should_compile = if out_file.exists() {
kernel_dir
.read_dir()
.expect("kernels folder should exist")
.any(|entry| {
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
let in_modified = entry.metadata().unwrap().modified().unwrap();
in_modified.duration_since(*out_modified).is_ok()
} else {
true
}
})
} else {
true
};
if should_compile {
cu_files
.par_iter()
.map(|(cu_file, obj_file)| {
let mut command = std::process::Command::new("nvcc");
command
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("-c")
.args(["-o", obj_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--verbose");
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
&command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
Ok(())
})
.collect::<Result<()>>()?;
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
let mut command = std::process::Command::new("nvcc");
command
.arg("--lib")
.args(["-o", out_file.to_str().unwrap()])
.args(obj_files);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
&command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
println!("cargo:rustc-link-search={}", build_dir.display()); println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention"); println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart"); println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++"); println!("cargo:rustc-link-lib=dylib=stdc++");
/* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
finishing to run for some reason. Calling nvcc manually worked fine.
cc::Build::new()
.cuda(true)
.include("cutlass/include")
.flag("--expt-relaxed-constexpr")
.flag("--default-stream")
.flag("per-thread")
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
.compile("flashattn");
*/
 Ok(())
 }
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse compute cap")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}

View File

@@ -0,0 +1,62 @@
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_causal, typename Engine, typename Layout>
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int max_seqlen_k,
const int row_idx_offset,
const int max_seqlen_q,
const int warp_row_stride,
const float alibi_slope) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
}
}
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
}
}
}
}
} // namespace flash
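Note: the new `alibi.h` kernel adds the ALiBi position bias to the attention scores held in registers: with a causal mask the bias only depends on the key column, otherwise it is `-slope * |row + max_seqlen_k - max_seqlen_q - col|`. A CPU reference sketch of the same formula (illustration only, independent of the CUDA register layout above):

```rust
// scores[q][k]: attention logits for one head; query length seqlen_q, key length seqlen_k.
fn apply_alibi_cpu(scores: &mut [Vec<f32>], slope: f32, seqlen_q: usize, seqlen_k: usize, causal: bool) {
    for q in 0..seqlen_q {
        for k in 0..seqlen_k {
            if causal {
                // Same per-column bias for every row; this differs from the usual
                // -slope * (query_pos - key_pos) form only by a per-row constant,
                // which the row-wise softmax cancels out.
                scores[q][k] += slope * k as f32;
            } else {
                let rel = q as i64 + seqlen_k as i64 - seqlen_q as i64 - k as i64;
                scores[q][k] -= slope * rel.abs() as f32;
            }
        }
    }
}
```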

View File

@@ -14,9 +14,12 @@ struct BlockInfo {
 template<typename Params>
 __device__ BlockInfo(const Params &params, const int bidb)
 : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
 , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
-, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
 {
 }
@@ -32,8 +35,10 @@ struct BlockInfo {
 const int sum_s_q;
 const int sum_s_k;
-const uint32_t actual_seqlen_q;
-const uint32_t actual_seqlen_k;
+const int actual_seqlen_q;
+// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+const int seqlen_k_cache;
+const int actual_seqlen_k;
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -7,15 +7,6 @@
 #include <cuda.h>
 #include <vector>
-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh>
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
 constexpr int D_DIM = 2;
@@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params {
 // The O matrix (output).
 void * __restrict__ o_ptr;
+void * __restrict__ oaccum_ptr;
 // The stride between rows of O.
 index_t o_batch_stride;
@@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params {
 // The pointer to the softmax sum.
 void * __restrict__ softmax_lse_ptr;
+void * __restrict__ softmax_lseaccum_ptr;
 // The dimensions.
-int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
 // The scaling factors for the kernel.
 float scale_softmax;
@@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params {
 int * __restrict__ cu_seqlens_q;
 int * __restrict__ cu_seqlens_k;
+// If provided, the actual length of each k sequence.
+int * __restrict__ seqused_k;
 int *__restrict__ blockmask;
+// The K_new and V_new matrices.
+void * __restrict__ knew_ptr;
+void * __restrict__ vnew_ptr;
+// The stride between rows of the Q, K and V matrices.
+index_t knew_batch_stride;
+index_t vnew_batch_stride;
+index_t knew_row_stride;
+index_t vnew_row_stride;
+index_t knew_head_stride;
+index_t vnew_head_stride;
+// The cos and sin matrices for rotary embedding.
+void * __restrict__ rotary_cos_ptr;
+void * __restrict__ rotary_sin_ptr;
+// The indices to index into the KV cache.
+int *__restrict__ cache_batch_idx;
 // The dropout probability (probability of keeping an activation).
 float p_dropout;
 // uint32_t p_dropout_in_uint;
@@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params {
 float rp_dropout;
 float scale_softmax_rp_dropout;
-// Random state.
-// at::PhiloxCudaState philox_args;
+// Local window size
+int window_size_left, window_size_right;
 bool is_bf16;
 bool is_causal;
+// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+bool is_seqlens_k_cumulative;
+bool is_rotary_interleaved;
+int num_splits; // For split-KV version
+void * __restrict__ alibi_slopes_ptr;
+index_t alibi_slopes_batch_stride;
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params {
 // The pointer to the softmax d sum.
 void *__restrict__ dsoftmax_sum;
+bool deterministic;
+index_t dq_accum_split_stride;
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);

View File

@@ -1,17 +1,15 @@
 #include "flash_fwd_launch_template.h"
-// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-// FWD_HEADDIM_SWITCH(params.d, [&] {
-// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
-// });
-// }
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
 FP16_SWITCH(!params.is_bf16, [&] {
 FWD_HEADDIM_SWITCH(params.d, [&] {
+// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
 run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+// } else {
+// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
+// }
 });
 });
 }
 extern "C" void run_mha(
@ -20,6 +18,7 @@ extern "C" void run_mha(
void *v_ptr, void *v_ptr,
void *o_ptr, void *o_ptr,
void *softmax_lse_ptr, void *softmax_lse_ptr,
void *alibi_slopes_ptr,
int32_t *cu_seqlens_q_ptr,
int32_t *cu_seqlens_k_ptr,
@ -28,6 +27,7 @@ extern "C" void run_mha(
uint32_t k_batch_stride,
uint32_t v_batch_stride,
uint32_t o_batch_stride,
uint32_t alibi_slopes_batch_stride,
uint32_t q_row_stride,
uint32_t k_row_stride,
@ -51,8 +51,11 @@ extern "C" void run_mha(
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,
int is_bf16,
int is_causal,
int is_bf16
int window_size_left,
int window_size_right
) {
Flash_fwd_params params;
// Reset the parameters
@ -65,12 +68,14 @@ extern "C" void run_mha(
params.o_ptr = o_ptr;
params.softmax_lse_ptr = softmax_lse_ptr;
params.alibi_slopes_ptr = alibi_slopes_ptr;
// All stride are in elements, not bytes.
params.q_batch_stride = q_batch_stride;
params.k_batch_stride = k_batch_stride;
params.v_batch_stride = v_batch_stride;
params.o_batch_stride = o_batch_stride;
params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;
params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride;
@ -92,7 +97,6 @@ extern "C" void run_mha(
params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d;
params.d_rounded = d_rounded;
params.is_causal = is_causal;

// Set the different scale values.
params.scale_softmax = softmax_scale;
@ -106,6 +110,14 @@ extern "C" void run_mha(
params.cu_seqlens_q = cu_seqlens_q_ptr;
params.cu_seqlens_k = cu_seqlens_k_ptr;
params.p_ptr = nullptr; // used for `return_softmax`.
params.seqused_k = nullptr;
params.is_causal = is_causal;
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;
params.is_seqlens_k_cumulative = true;
params.num_splits = 1;
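// Editorial note (not from the sources): seqused_k = nullptr appears to mean that no per-batch
// override of the key lengths is supplied, and num_splits = 1 keeps the single-kernel path;
// the split-KV dispatch in run_mha_fwd above is left commented out accordingly.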
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);
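The new window_size_left / window_size_right arguments control sliding-window (local) attention. As a hedged illustration only, derived from the masking bounds used in flash_fwd_kernel.h further down in this commit and not itself part of the sources, a key column j is visible to a query row i when:

// Illustration: visibility of key column j for query row i, with the two sequences
// aligned at their ends as flash-attention does. A negative window size is assumed
// here to mean "unbounded on that side" (assumption, not shown in this diff).
inline bool key_visible(int i, int j, int seqlen_q, int seqlen_k,
                        int window_size_left, int window_size_right) {
    const int shift = seqlen_k - seqlen_q;          // aligns the last query with the last key
    const bool left_ok  = window_size_left  < 0 || j >= i + shift - window_size_left;
    const bool right_ok = window_size_right < 0 || j <= i + shift + window_size_right;
    return left_ok && right_ok;
}

Causal masking is the special case of an unbounded left window with window_size_right == 0, which matches the n_block_max bound below that adds params.window_size_right even in the causal branch.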
View File
@ -1,19 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// if (params.p_dropout == 1.f) {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
// }
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
}
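The remaining forward-kernel translation units in this commit follow the same shape; only the element type (cutlass::half_t or cutlass::bfloat16_t) and the head dimension change. As an editorial sketch of the common pattern, not an additional file in the commit:

// Generic shape of the per-(dtype, head-dim) files below; T and HEADDIM stand for the
// concrete element type and head dimension (32..256) in each file.
// template<>
// void run_mha_fwd_<T, HEADDIM>(Flash_fwd_params &params, cudaStream_t stream) {
//     run_mha_fwd_hdimHEADDIM<T>(params, stream);
// }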
View File
@ -1,32 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// if (params.p_dropout == 1.f) {
// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
// // 1st ones are good for H100, A100
// // 2nd one is good for A6000 bc we get slightly better occupancy
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
// // 1st one is good for H100, A100, A6000
// }
// }
template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}
View File
@ -1,17 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
}
View File
@ -1,27 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
// // For A100, H100, 1st is fastest.
// });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
}
View File
@ -1,16 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
}
View File
@ -1,27 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // This one is slightly faster for causal?
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
// });
// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
// // For A6000, 1st is faster when causal, 3rd is faster when not causal
// }
template<>
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
}
View File
@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
}
View File
@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
template<>
void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
}
View File
@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
}
View File
@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
template<>
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}
View File
@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
}
View File
@ -1,23 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // For dropout there might be a lot of register spilling?
// // These two are very slow due to register spilling
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
// // This one is slightly slower
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
// });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
}
View File
@ -1,19 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// if (params.p_dropout == 1.f) {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
// }
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
}
View File
@ -1,26 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// if (params.p_dropout == 1.f) {
// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// // Using block size (64 x 256) is 27% slower for seqlen=2k
// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
// }
// }
template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
}
View File
@ -1,17 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
}
View File
@ -1,23 +1,10 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
// // This 3rd one is good for H100, and A100, A6000
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
// // These two are always slower
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
// });
// }
template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
template<>
void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
}
View File
@ -4,20 +4,18 @@
#pragma once

#include <cmath>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "philox.cuh"
#include "alibi.h"

namespace flash {
@ -25,49 +23,6 @@ using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMA_M,
class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
constexpr int MMAStride_M = MMA_M * AtomShape_M;
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
Stride<_1, Int<MMAStride_M>> >{},
make_layout(size<2>(TileShape_MNK{})));
// if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMA_M,
class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
constexpr int MMAStride_M = MMA_M * AtomShape_M;
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
Stride<_1, Int<MMAStride_M>> >{},
// TODO: Shouldn't this be size<1>?
make_layout(size<2>(TileShape_MNK{})));
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
Tensor2 &acc_o, float softmax_scale_log2) {
@ -77,7 +32,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
flash::reduce_sum(scores, scores_sum);
} else {
Tensor scores_max_prev = make_fragment_like(scores_max);
copy(scores_max, scores_max_prev);
cute::copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
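// Editorial sketch (not from the sources): this branch implements the usual online-softmax
// update used by flash attention. With m = running row max, l = running row sum and
// acc_o = running output accumulator, folding in a new score block S amounts to:
//   m_new  = max(m, rowmax(S))
//   acc_o *= exp2((m - m_new) * scale_softmax_log2)    // the rescale applied to acc_o_rowcol
//   l      = l * exp2((m - m_new) * scale_softmax_log2)
//            + rowsum(exp2(S * scale_softmax_log2 - m_new * scale_softmax_log2))
//   m      = m_new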
@ -103,23 +58,22 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout();
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
// TODO(laurent): reactivate the following
// CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
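// Editorial note (not from the sources): relative to the old signature, Is_local enables
// sliding-window attention driven by params.window_size_left/right, Has_alibi adds the
// per-(batch, head) ALiBi bias read from params.alibi_slopes_ptr, and Is_even_MN generalizes
// the former Is_even_N flag to divisibility of both seqlen_q and seqlen_k by the block sizes.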
using Element = typename Kernel_traits::Element;
@ -138,16 +92,65 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;

const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal) {
if (Is_causal || Is_local) {
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
n_block_max = std::min(n_block_max,
cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// }
}
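// Editorial note (not from the sources): a worked instance of the bounds above. With
// kBlockM = kBlockN = 64, seqlen_q = seqlen_k = 512, window_size_left = 128,
// window_size_right = 0 and m_block = 3:
//   n_block_min = max(0, (3*64 + 512 - 512 - 128) / 64)              = 1
//   n_block_max = min(ceil(512/64), ceil((4*64 + 512 - 512 + 0)/64)) = min(8, 4) = 4
// so this row block only visits key blocks 1..3.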
// We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
// Otherwise we might read OOB elements from gK and gV.
if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
// exit early and no one saves the rng state.
// if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
// auto seeds = at::cuda::philox::unpack(params.philox_args);
// params.rng_state[0] = std::get<0>(seeds);
// params.rng_state[1] = std::get<1>(seeds);
// params.rng_state[0] = 0;
// params.rng_state[1] = 0;
// }
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_tensor<Element>(shape(tOgO));
clear(tOrO);
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
#pragma unroll
for (int m = 0; m < size<1>(tOgO); ++m) {
const int row = get<0>(tOcO(0, m, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
}
return;
}
// if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
// We iterate over the blocks in reverse order. This is because the last block is the only one
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
@ -185,8 +188,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);

Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
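// Editorial note (not from the sources): the recurring change in this hunk is that the
// tiled-copy object (gmem_tiled_copy_* / smem_tiled_copy_*) is now kept alongside its
// per-thread slice (gmem_thr_copy_* / smem_thr_copy_*): cute::copy and flash::copy are
// called with the tiled object, while tensor partitioning still uses the thread slice.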
@ -208,16 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Copy Atom retiling
//

auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}

auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);

auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

// TODO: this might need to change if we change the mma instruction in SM70
@ -268,8 +275,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

// // Copy rmem to smem
@ -285,14 +292,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
}

int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
// __syncthreads();
@ -302,7 +309,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
}

// auto seeds = at::cuda::philox::unpack(params.philox_args);
@ -313,13 +320,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
clear(acc_o);
float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
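// Editorial note (not from the sources): the slope is looked up per (batch, head) and divided
// by scale_softmax because the ALiBi bias is added to the raw Q*K^T scores, which are only
// multiplied by the softmax scale afterwards: (s + slope/scale * d) * scale == s*scale + slope*d,
// so the effective bias on the scaled logits is the unscaled slope times the position offset d.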
// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
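// Editorial note (not from the sources): e.g. with kBlockM = 128 and kBlockN = 64, a causal,
// evenly divisible problem needs ceil(128/64) = 2 masking iterations; if seqlen_k is not a
// multiple of kBlockN, or the attention is local rather than causal, one extra iteration is
// masked as well.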
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
@ -330,28 +343,42 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();

flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }

// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread0()) { print(scores); }
// if (cute::thread0()) { print_tensor(scores); }

// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (!Is_causal) {
if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }

if (Has_alibi) {
flash::apply_alibi<Is_causal>(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
alibi_slope
);
}
if (!Is_causal && !Is_local) {
if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
@ -364,20 +391,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
flash::apply_mask_local</*HasWSLeft=*/Is_local>(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
);
// if (cute::thread0()) { print_tensor(scores); }
}
flash::cp_async_wait<0>();
__syncthreads();

if (n_block > 0) {
if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@ -385,24 +416,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
uint32_t block_col_idx = n_block * (kBlockN / 32);
int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
@ -411,37 +442,38 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}
// if (cute::thread0()) { print(tOrP); }
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }

// This check is at the end of the loop since we always have at least 1 iteration
if (n_masking_steps > 1 && n_block <= 0) {
if (n_masking_steps > 1 && n_block <= n_block_min) {
--n_block;
break;
}
}

// These are the iterations where we don't need masking on S
for (; n_block >= 0; --n_block) {
for (; n_block >= n_block_min; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();

// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();

flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);

flash::cp_async_wait<0>();
__syncthreads();
if (n_block > 0) {
if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@ -449,22 +481,44 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
if (Has_alibi) {
flash::apply_alibi<Is_causal>(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
alibi_slope
);
}
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
flash::apply_mask_local(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
);
}
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
uint32_t block_col_idx = n_block * (kBlockN / 32);
int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
@ -472,7 +526,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
block_row_idx, block_col_idx, kNWarps);
}
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}

// Epilogue
@ -496,15 +550,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor rO = flash::convert_type<Element>(acc_o);
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)

// sO has the same size as sQ, so we don't need to sync here.
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }

copy(smem_thr_copy_O, taccOrO, taccOsO);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
@ -515,14 +569,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});

auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

__syncthreads();

Tensor tOrO = make_tensor<Element>(shape(tOgO));
copy(gmem_thr_copy_O, tOsO, tOrO);
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);

Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
@ -548,14 +603,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
} }
// Clear_OOB_K must be false since we don't want to write zeros to gmem // Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
); );
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params> template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) { inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x; const int m_block = blockIdx.x;
// The block index for the batch. // The block index for the batch.
@ -571,7 +627,7 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of // the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block); flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -4,15 +4,14 @@
#pragma once #pragma once
// #include <ATen/cuda/CUDAContext.h>
#include "static_switch.h" #include "static_switch.h"
#include "flash.h" #include "flash.h"
#include "flash_fwd_kernel.h" #include "flash_fwd_kernel.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax> template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) { __global__ void flash_fwd_kernel(Flash_fwd_params params) {
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params); static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
} }
template<typename Kernel_traits, bool Is_dropout, bool Is_causal> template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@ -26,35 +25,39 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h); dim3 grid(num_m_block, params.b, params.h);
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
// for cu_seqlens_q as well.
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr; const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
// Will only return softmax if dropout, to reduce compilation time. BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>; BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>; // Will only return softmax if dropout, to reduce compilation time.
// if (smem_size >= 48 * 1024) { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// C10_CUDA_CHECK(cudaFuncSetAttribute( // If return_softmax, set IsEvenMNConst to false to reduce number of templates
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// } // If Is_local, set Is_causal to false
int ctas_per_sm; auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); // int ctas_per_sm;
// C10_CUDA_KERNEL_LAUNCH_CHECK(); // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
}); });
}); });
}); });
} }
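The launcher above folds the runtime flags (even-MN/K, local window, ALiBi, return-softmax, dropout) into template parameters through nested BOOL_SWITCH calls, so every flag combination is compiled as its own kernel. A rough Rust analogue of this compile-time dispatch pattern using const generics; the names and function body are purely illustrative, not code from this repository:

    fn kernel<const IS_CAUSAL: bool, const IS_EVEN_K: bool>(x: f32) -> f32 {
        // Branches on const parameters are resolved per monomorphized instance,
        // much like template booleans in the CUDA launcher above.
        let mut y = x;
        if IS_CAUSAL {
            y += 1.0;
        }
        if IS_EVEN_K {
            y *= 2.0;
        }
        y
    }

    fn dispatch(is_causal: bool, is_even_k: bool, x: f32) -> f32 {
        // Each arm instantiates a distinct specialization, mirroring one
        // flash_fwd_kernel instantiation per BOOL_SWITCH combination.
        match (is_causal, is_even_k) {
            (true, true) => kernel::<true, true>(x),
            (true, false) => kernel::<true, false>(x),
            (false, true) => kernel::<false, true>(x),
            (false, false) => kernel::<false, false>(x),
        }
    }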
template<typename T> template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 32; constexpr static int Headdim = 32;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
@ -64,7 +67,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 64; constexpr static int Headdim = 64;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) { if constexpr(!Is_dropout) {
@ -86,7 +89,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 96; constexpr static int Headdim = 96;
// auto dprops = at::cuda::getCurrentDeviceProperties(); // auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@ -112,7 +115,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 128; constexpr static int Headdim = 128;
// auto dprops = at::cuda::getCurrentDeviceProperties(); // auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@ -149,7 +152,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 160; constexpr static int Headdim = 160;
// auto dprops = at::cuda::getCurrentDeviceProperties(); // auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@ -179,7 +182,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 192; constexpr static int Headdim = 192;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) { if constexpr(!Is_dropout) {
@ -198,7 +201,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 224; constexpr static int Headdim = 224;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
int max_smem_per_block; int max_smem_per_block;
@ -224,7 +227,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 256; constexpr static int Headdim = 256;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
int max_smem_per_sm, max_smem_per_block; int max_smem_per_sm, max_smem_per_block;

View File

@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ{}, SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{})); Shape<Int<kBlockN>, Int<kHeadDim>>{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomVtransposed = decltype( using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape( using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{}, SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{})); Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape // Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter? // And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomVtransposedNoSwizzle{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype( using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{},
@ -110,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape( using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{}, SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{})); Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>; using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
static constexpr int kSmemQCount = size(SmemLayoutQ{}); static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
@ -138,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base {
DefaultCopy DefaultCopy
>; >;
using GmemTiledCopyQKV = decltype( using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{}, make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{}, GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype( using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{}, GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
@ -151,10 +155,30 @@ struct Flash_fwd_kernel_traits : public Base {
Stride<Int<kGmemThreadsPerRowP>, _1>>; Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype( using GmemTiledCopyP = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomP{}, GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomOaccum = std::conditional_t<
kBlockKSmem == 32,
Layout<Shape <_16, _8>, // Thread layout, 8 threads per row
Stride< _8, _1>>,
Layout<Shape <_8, _16>, // Thread layout, 16 threads per row
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
using GmemTiledCopyRotcossin = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
}; };
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure. // Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
@ -223,16 +247,19 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV{}, SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{}))); make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomKtransposed = decltype( using SmemLayoutAtomKtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutKtransposed = decltype(tile_to_shape( using SmemLayoutKtransposed = decltype(tile_to_shape(
SmemLayoutAtomKtransposed{}, SmemLayoutAtomKtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{}))); make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape // Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter? // And the strides don't matter?
using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomKtransposedNoSwizzle{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN // TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
@ -250,24 +277,30 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutPdS = decltype(tile_to_shape( using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{}, SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{}))); make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
Stride<_1, Int<kPBlockN>>>;
using SmemLayoutAtomPdStransposed = decltype( using SmemLayoutAtomPdStransposed = decltype(
composition(Swizzle<kSwizzlePdS, 3, 3>{}, composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
Stride<_1, Int<kPBlockN>>>{}));
using SmemLayoutPdStransposed = decltype(tile_to_shape( using SmemLayoutPdStransposed = decltype(tile_to_shape(
SmemLayoutAtomPdStransposed{}, SmemLayoutAtomPdStransposed{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{}))); make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomPdStransposedNoSwizzle{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>; using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomQdOtransposed = decltype( using SmemLayoutAtomQdOtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutQdOtransposed = decltype(tile_to_shape( using SmemLayoutQdOtransposed = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposed{}, SmemLayoutAtomQdOtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{}))); make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposedNoSwizzle{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using SmemLayoutAtomdKV = decltype( using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{},
@ -292,13 +325,11 @@ struct Flash_bwd_kernel_traits : public Base {
static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
static constexpr int kSmemPCount = size(SmemLayoutPdS{}); static constexpr int kSmemPCount = size(SmemLayoutPdS{});
static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
static constexpr int kSmemdPsumCount = kBlockM;
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
static constexpr int kSmemSize = kSmemQdOSize static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs + (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)

View File

@ -0,0 +1,159 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>
using namespace cute;
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits_sm90 {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using Element = elem_type;
static constexpr bool Has_cp_async = true;
#else
using Element = cutlass::half_t;
static constexpr bool Has_cp_async = false;
#endif
using ElementAccum = float;
using index_t = uint32_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using SmemLayoutAtomQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutKV = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
// to the same banks.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
// from the same address by the same threadblock. This is slightly faster.
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
};
////////////////////////////////////////////////////////////////////////////////////////////////////
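The trait constants in this new header follow the same arithmetic as the forward traits above: shared memory is sized from the Q and K/V tiles (or their max when Q and K share smem), and the global-memory copy geometry comes from 128-bit vectorized accesses. A host-side Rust sketch of that arithmetic with illustrative values, assuming 2-byte (fp16/bf16) elements:

    fn traits_numbers(k_block_m: usize, k_block_n: usize, k_head_dim: usize, k_n_warps: usize, share_q_k_smem: bool) {
        let elem_bytes = 2; // sizeof(half_t) or sizeof(bfloat16_t)
        let smem_q = k_block_m * k_head_dim * elem_bytes;
        let smem_kv = 2 * k_block_n * k_head_dim * elem_bytes; // one K tile + one V tile
        let smem_total = if share_q_k_smem { smem_q.max(smem_kv) } else { smem_q + smem_kv };

        let gmem_elems_per_load = 16 / elem_bytes; // 128-bit loads -> 8 half-precision values
        let k_block_k_smem = if k_head_dim % 64 == 0 { 64 } else { 32 };
        let gmem_threads_per_row = k_block_k_smem / gmem_elems_per_load;
        let rows_per_copy = (k_n_warps * 32) / gmem_threads_per_row;

        println!("smem = {smem_total} B, {gmem_threads_per_row} threads/row, {rows_per_copy} rows per copy step");
    }
    // e.g. kBlockM = 128, kBlockN = 64, d = 128, 4 warps, no Q/K sharing:
    // 32 KiB (Q) + 32 KiB (K + V) = 64 KiB of smem; 8 threads per row, 16 rows per copy step.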

View File

@ -8,8 +8,7 @@
#include <cute/tensor.hpp> #include <cute/tensor.hpp>
#include <cutlass/cutlass.h> #include <cutlass/numeric_types.h>
#include <cutlass/array.h>
#include "philox.cuh" #include "philox.cuh"
#include "utils.h" #include "utils.h"
@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
} }
template <typename Engine, typename Layout> template <typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) { inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
const int col_idx_offset_ = 0) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor"); static_assert(Layout::rank == 2, "Only support 2D Tensor");
const uint32_t lane_id = threadIdx.x % 32; const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll #pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll #pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) { for (int j = 0; j < size<1, 0>(tensor); ++j) {
const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2; const int col_idx = col_idx_base + j;
if (col_idx >= max_seqlen_k) { if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results // Without the "make_coord" we get wrong results
#pragma unroll #pragma unroll
@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
} }
} }
template <typename Engine, typename Layout> template <bool HasWSLeft=true, typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_, inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_, const int max_seqlen_k, const int row_idx_offset,
const uint32_t warp_row_stride) { const int max_seqlen_q, const int warp_row_stride,
const int window_size_left, const int window_size_right) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor"); static_assert(Layout::rank == 2, "Only support 2D Tensor");
const uint32_t lane_id = threadIdx.x % 32; const int lane_id = threadIdx.x % 32;
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
const uint32_t row_idx_offset = row_idx_offset_;
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride; const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) { for (int i = 0; i < size<0, 0>(tensor); ++i) {
const uint32_t row_idx = row_idx_base + i * 8; const int row_idx = row_idx_base + i * 8;
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1); const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll #pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const uint32_t col_idx_base = col_idx_offset + nj * 8; const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll #pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) { for (int j = 0; j < size<1, 0>(tensor); ++j) {
const uint32_t col_idx = col_idx_base + j; const int col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit) { if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
} }
} }
@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
} }
} }
template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset,
const int max_seqlen_q, const int warp_row_stride) {
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
max_seqlen_q, warp_row_stride, -1, 0);
}
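The masking above bounds each query row to a window of visible keys; causal attention is recovered by taking the left window as unbounded and the right window as 0. A host-side Rust sketch of the same limit arithmetic (illustrative helper, not part of the kernels):

    // For query row `row_idx`, keys with index outside [left, right) are set to -inf.
    fn local_mask_limits(
        row_idx: i64,
        max_seqlen_q: i64,
        max_seqlen_k: i64,
        window_size_left: i64,  // < 0 means unbounded on the left
        window_size_right: i64,
    ) -> (i64, i64) {
        // The last query is aligned with the last key, hence the seqlen difference.
        let shift = max_seqlen_k - max_seqlen_q;
        let left = if window_size_left < 0 {
            0
        } else {
            (row_idx + shift - window_size_left).max(0)
        };
        let right = (row_idx + 1 + shift + window_size_right).min(max_seqlen_k);
        (left, right)
    }

    // Causal case: window_size_left < 0, window_size_right = 0
    // => right = min(max_seqlen_k, row_idx + 1 + shift), i.e. no key after the query is visible.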
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1> template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx( inline __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol, Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_) const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{ {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout0::rank == 2, "Only support 2D Tensor");
@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) { for (int mi = 0; mi < size<0>(tensor); ++mi) {
const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx(
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout> template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t, inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
unsigned long long seed, unsigned long long offset, unsigned long long seed, unsigned long long offset,
uint32_t block_row_start, uint32_t block_col_start, int block_row_start, int block_col_start,
uint32_t block_row_stride) { int block_row_stride) {
// tensor has shape (8, MMA_M, MMA_N / 2) // tensor has shape (8, MMA_M, MMA_N / 2)
using T = typename Engine::value_type; using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) { auto encode_dropout = [](bool keep, T val) {

View File

@ -87,46 +87,6 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float2 half2_unpack(uint32_t a);
template <>
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
return __half22float2(reinterpret_cast<__half2 (&)>(a));
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two half2's or bf162's into float, then take their dot product.
template <typename T>
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
float2 af = flash::half2_unpack<T>(a);
float2 bf = flash::half2_unpack<T>(b);
return af.x * bf.x + af.y * bf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
template<typename T>
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
float sum;
sum = flash::hfma2_to_float<T>(a.x, b.x);
sum += flash::hfma2_to_float<T>(a.y, b.y);
sum += flash::hfma2_to_float<T>(a.z, b.z);
sum += flash::hfma2_to_float<T>(a.w, b.w);
return sum;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
struct MaxOp { struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } __device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
@ -173,10 +133,12 @@ static __device__ inline T run(T x, Operator &op) {
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1, template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4, typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy0, typename TiledCopy1> typename TiledMma, typename TiledCopyA, typename TiledCopyB,
typename ThrCopyA, typename ThrCopyB>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma, Tensor4 const& tCsB, TiledMma tiled_mma,
TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) { TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
@ -184,13 +146,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
#pragma unroll #pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) { for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) { if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
} }
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
} }
@ -199,19 +161,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy> typename TiledMma, typename TiledCopy, typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_thr_copy_B) { TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll #pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) { for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) { if (i < size<2>(tCrA) - 1) {
copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
} }
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
} }
@ -225,7 +188,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3); static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -241,9 +207,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
static_assert(mma_shape_K == 8 || mma_shape_K == 16); static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), // TD [2023-08-13]: Same error as above on Cutlass 3.2
get<0, 1>(l), // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
get<1, 1, 1>(l)); // get<0, 1>(l),
// get<1, 1, 1>(l));
return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
get<1>(get<0>(l)),
get<1>(get<1>(get<1>(l))));
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -319,9 +289,9 @@ void cp_async_wait() {
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true, template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3> typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S, inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN, Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) { Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
@ -335,13 +305,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
#pragma unroll #pragma unroll
for (int k = 0; k < size<2>(S); ++k) { for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) { if (Is_even_K || predicate_K(k)) {
copy(thr_copy, S(_, m, k), D(_, m, k)); cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) { } else if (Clear_OOB_K) {
clear(D(_, m, k)); cute::clear(D(_, m, k));
} }
} }
} else if (Clear_OOB_MN) { } else if (Clear_OOB_MN) {
clear(D(_, m, _)); cute::clear(D(_, m, _));
} }
} }
// TD [2023-04-13]: Strange that the code below can cause race condition. // TD [2023-04-13]: Strange that the code below can cause race condition.
@ -350,7 +320,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
// #pragma unroll // #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) { // for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(thr_copy, S(_, m, _), D(_, m, _)); // copy(tiled_copy, S(_, m, _), D(_, m, _));
// } else if (Clear_OOB_MN) { // } else if (Clear_OOB_MN) {
// clear(D(_, m, _)); // clear(D(_, m, _));
// } // }
@ -362,7 +332,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
// #pragma unroll // #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) { // for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(thr_copy, S(_, m, k), D(_, m, k)); // copy(tiled_copy, S(_, m, k), D(_, m, k));
// } else if (Clear_OOB_MN) { // } else if (Clear_OOB_MN) {
// clear(D(_, m, k)); // clear(D(_, m, k));
// } // }

View File

@ -7,6 +7,8 @@ extern "C" {
v_ptr: *const c_void, v_ptr: *const c_void,
o_ptr: *const c_void, o_ptr: *const c_void,
softmax_lse_ptr: *const c_void, softmax_lse_ptr: *const c_void,
alibi_slopes_ptr: *const c_void,
cu_seqlens_q_ptr: *const i32, cu_seqlens_q_ptr: *const i32,
cu_seqlens_k_ptr: *const i32, cu_seqlens_k_ptr: *const i32,
@ -14,6 +16,7 @@ extern "C" {
k_batch_stride: u32, k_batch_stride: u32,
v_batch_stride: u32, v_batch_stride: u32,
o_batch_stride: u32, o_batch_stride: u32,
alibi_slopes_batch_stride: u32,
q_row_stride: u32, q_row_stride: u32,
k_row_stride: u32, k_row_stride: u32,
@ -37,8 +40,11 @@ extern "C" {
seqlen_q_rounded: u32, seqlen_q_rounded: u32,
seqlen_k_rounded: u32, seqlen_k_rounded: u32,
is_causal: c_int,
is_bf16: c_int, is_bf16: c_int,
is_causal: c_int,
window_size_left: c_int,
window_size_right: c_int,
); );
} }

View File

@ -3,12 +3,14 @@ mod ffi;
use candle::backend::BackendStorage; use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr; use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr; use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Layout, Result, Shape, Tensor}; use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use half::{bf16, f16}; use half::{bf16, f16};
pub struct FlashAttn { pub struct FlashAttn {
pub softmax_scale: f32, pub softmax_scale: f32,
pub causal: bool, pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
} }
fn round_multiple(x: usize, m: usize) -> usize { fn round_multiple(x: usize, m: usize) -> usize {
@ -85,6 +87,51 @@ impl FlashAttn {
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
} }
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
"DType mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes.dtype(),
DType::F32
);
}
let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
if num_heads != alibi_slopes_layout.shape().dims1()? {
candle::bail!(
"shape mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes_layout.shape(),
(num_heads)
);
}
let alibi_slopes = match &*alibi_slopes {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
};
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
*alibi_slopes.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};
// If window_size_left is None or exceeds seqlen_k, pass -1 (unbounded) to the kernel.
let mut window_size_left = self
.window_size_left
.filter(|v| v <= &seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
// If window_size_right is None or exceeds seqlen_k, pass -1 (unbounded) to the kernel.
let mut window_size_right = self
.window_size_right
.filter(|v| v <= &seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
let head_size = round_multiple(head_size_og, 8); let head_size = round_multiple(head_size_og, 8);
let head_size_rounded = round_multiple(head_size, 32); let head_size_rounded = round_multiple(head_size, 32);
let seqlen_q_rounded = round_multiple(seqlen_q, 128); let seqlen_q_rounded = round_multiple(seqlen_q, 128);
@ -94,9 +141,22 @@ impl FlashAttn {
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?; let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?; let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let causal = if self.causal { 1 } else { 0 };
let is_bf16 = if is_bf16 { 1 } else { 0 }; let is_bf16 = if is_bf16 { 1 } else { 0 };
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
let is_causal = if window_size_left < 0 && window_size_right == 0 {
1
} else {
0
};
if window_size_left < 0 && window_size_right >= 0 {
window_size_left = seqlen_k as i32;
}
if window_size_left >= 0 && window_size_right < 0 {
window_size_right = seqlen_k as i32;
}
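Condensed, the window handling above can be read as the following Rust helper (an illustrative restatement, not an additional API): None or oversized windows become -1, causal is detected as (left unbounded, right == 0), and any remaining -1 alongside a bounded side is widened to seqlen_k before crossing the FFI boundary.

    fn normalize_windows(left: Option<usize>, right: Option<usize>, seqlen_k: usize) -> (i32, i32, i32) {
        let mut l = left.filter(|v| *v <= seqlen_k).map(|v| v as i32).unwrap_or(-1);
        let mut r = right.filter(|v| *v <= seqlen_k).map(|v| v as i32).unwrap_or(-1);
        let is_causal = if l < 0 && r == 0 { 1 } else { 0 };
        if l < 0 && r >= 0 {
            l = seqlen_k as i32;
        }
        if l >= 0 && r < 0 {
            r = seqlen_k as i32;
        }
        (l, r, is_causal)
    }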
unsafe { unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void; let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void; let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@ -109,12 +169,14 @@ impl FlashAttn {
v_ptr, v_ptr,
dst_ptr, dst_ptr,
softmax_lse_ptr, softmax_lse_ptr,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ std::ptr::null(), /* cu_seqlens_q_ptr */ std::ptr::null(),
/* cu_seqlens_k_ptr */ std::ptr::null(), /* cu_seqlens_k_ptr */ std::ptr::null(),
/* q_batch_stride */ q_stride[0] as u32, /* q_batch_stride */ q_stride[0] as u32,
/* k_batch_stride */ k_stride[0] as u32, /* k_batch_stride */ k_stride[0] as u32,
/* v_batch_stride */ v_stride[0] as u32, /* v_batch_stride */ v_stride[0] as u32,
/* o_batch_stride */ o_stride[0] as u32, /* o_batch_stride */ o_stride[0] as u32,
/* alibi_slopes_batch_stride */ 0,
/* q_row_stride */ q_stride[q_rank - 3] as u32, /* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32, /* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32, /* v_row_stride */ v_stride[v_rank - 3] as u32,
@ -133,8 +195,10 @@ impl FlashAttn {
/* seqlen_k */ seqlen_k as u32, /* seqlen_k */ seqlen_k as u32,
/* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
/* is_bf16 */ is_bf16, /* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
) )
} }
@ -197,20 +261,137 @@ pub fn flash_attn(
softmax_scale: f32, softmax_scale: f32,
causal: bool, causal: bool,
) -> Result<Tensor> { ) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttn { let op = FlashAttn {
softmax_scale, softmax_scale,
causal, alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
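A minimal usage sketch for the wrapper above, assuming a CUDA device, the crate built with its kernels, and f16 inputs; the shapes are illustrative:

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        let device = Device::new_cuda(0)?;
        // (batch, seq_len_q, num_heads_q, head_size); the kernels expect f16 or bf16.
        let q = Tensor::randn(0f32, 1.0, (1, 128, 32, 64), &device)?.to_dtype(DType::F16)?;
        let k = q.clone();
        let v = q.clone();
        let softmax_scale = 1f32 / (64f32).sqrt();
        let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, /* causal */ true)?;
        println!("{:?}", out.dims()); // [1, 128, 32, 64]
        Ok(())
    }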
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `window_size_left` - Limit attention to at most this many key/value tokens to the left of each query; `None` means no limit.
/// * `window_size_right` - Limit attention to at most this many key/value tokens to the right of each query; `None` means no limit.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
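As a usage sketch (not part of this commit): calling the new windowed entry point from a downstream crate. The CUDA device, f16 dtype, shapes, and the 64-token window are illustrative assumptions.

// Sketch: causal sliding-window attention via flash_attn_windowed.
fn sliding_window_example() -> candle::Result<candle::Tensor> {
    use candle::{DType, Device, Tensor};
    let dev = Device::new_cuda(0)?;
    let (b, seq, heads, head_dim) = (1, 128, 8, 64);
    let q = Tensor::randn(0f32, 1f32, (b, seq, heads, head_dim), &dev)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (b, seq, heads, head_dim), &dev)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (b, seq, heads, head_dim), &dev)?.to_dtype(DType::F16)?;
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    // Each query attends to at most 64 earlier tokens and to no future tokens.
    candle_flash_attn::flash_attn_windowed(&q, &k, &v, softmax_scale, Some(64), Some(0))
}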
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttn {
softmax_scale,
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
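A sketch of the ALiBi variant (again not part of this commit): the slopes tensor must be f32, live on the same CUDA device as q/k/v, and hold one slope per query head; the geometric schedule below is a common choice, not something the kernel requires.

use candle::Tensor;

// Sketch: ALiBi-biased causal attention; the slope schedule is illustrative.
fn alibi_attention_example(q: &Tensor, k: &Tensor, v: &Tensor) -> candle::Result<Tensor> {
    let num_heads = q.dim(2)?; // q: (batch, seq_len_q, num_heads_q, head_size)
    let slopes: Vec<f32> = (1..=num_heads)
        .map(|i| 2f32.powf(-8.0 * i as f32 / num_heads as f32))
        .collect();
    let alibi_slopes = Tensor::new(slopes, q.device())?; // f32, shape (num_heads_q,)
    let softmax_scale = 1f32 / (q.dim(3)? as f32).sqrt();
    candle_flash_attn::flash_attn_alibi(q, k, v, &alibi_slopes, softmax_scale, true)
}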
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `window_size_left` - Limit attention to at most this many key/value tokens to the left of each query; `None` means no limit.
/// * `window_size_right` - Limit attention to at most this many key/value tokens to the right of each query; `None` means no limit.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
struct FlashAttnVarLen {
softmax_scale: f32,
causal: bool,
max_seqlen_q: usize,
max_seqlen_k: usize,
seqlens_q: Tensor,
seqlens_k: Tensor,
pub softmax_scale: f32,
pub max_seqlen_q: usize,
pub max_seqlen_k: usize,
pub seqlens_q: Tensor,
pub seqlens_k: Tensor,
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
}
impl FlashAttnVarLen {
@ -311,7 +492,54 @@ impl FlashAttnVarLen {
if nseqlens_k != nseqlens_q {
candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
}
let batch_size = nseqlens_q - 1;
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
"DType mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes.dtype(),
DType::F32
);
}
let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
if num_heads != alibi_slopes_layout.shape().dims1()? {
candle::bail!(
"shape mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes_layout.shape(),
(num_heads)
);
}
let alibi_slopes = match &*alibi_slopes {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
};
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
*alibi_slopes.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};
// if window_size_left > self.max_seqlen_k or None => -1
let mut window_size_left = self
.window_size_left
.filter(|v| v <= &self.max_seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
// if window_size_right > self.max_seqlen_k or None => -1
let mut window_size_right = self
.window_size_right
.filter(|v| v <= &self.max_seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
let head_size = round_multiple(head_size_og, 8);
let head_size_rounded = round_multiple(head_size, 32);
let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
@ -323,9 +551,22 @@ impl FlashAttnVarLen {
.alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
.w()?;
let causal = if self.causal { 1 } else { 0 };
let is_bf16 = if is_bf16 { 1 } else { 0 };
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
let is_causal = if window_size_left < 0 && window_size_right == 0 {
1
} else {
0
};
if window_size_left < 0 && window_size_right >= 0 {
window_size_left = self.max_seqlen_k as i32;
}
if window_size_left >= 0 && window_size_right < 0 {
window_size_right = self.max_seqlen_k as i32;
}
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@ -340,12 +581,14 @@ impl FlashAttnVarLen {
v_ptr,
dst_ptr,
softmax_lse_ptr,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
/* q_batch_stride */ 0,
/* k_batch_stride */ 0,
/* v_batch_stride */ 0,
/* o_batch_stride */ 0,
/* alibi_slopes_batch_stride */ 0,
/* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32,
@ -364,8 +607,10 @@ impl FlashAttnVarLen {
/* seqlen_k */ self.max_seqlen_k as u32,
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
)
}
@ -440,13 +685,176 @@ pub fn flash_attn_varlen(
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttnVarLen {
softmax_scale,
causal,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
/// * `window_size_left` - Limit attention to at most this many key/value tokens to the left of each query; `None` means no limit.
/// * `window_size_right` - Limit attention to at most this many key/value tokens to the right of each query; `None` means no limit.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
pub fn flash_attn_varlen_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
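A sketch of the variable-length path (not part of this commit): q, k and v are packed along their first dimension and indexed by cumulative sequence lengths. The u32 offset dtype, the two sequence lengths, and the 16-token window are illustrative assumptions.

use candle::{Device, Tensor};

// Sketch: packed variable-length batching with a causal sliding window.
// Two sequences (lengths 3 and 5) are packed as q/k/v of shape (8, num_heads, head_size).
fn varlen_windowed_example(q: &Tensor, k: &Tensor, v: &Tensor, dev: &Device) -> candle::Result<Tensor> {
    // batch_size + 1 cumulative offsets: 0, 3, 3 + 5.
    let cu_seqlens = Tensor::new(&[0u32, 3, 8], dev)?;
    let softmax_scale = 1f32 / (q.dim(2)? as f32).sqrt(); // q: (total_q, num_heads_q, head_size)
    candle_flash_attn::flash_attn_varlen_windowed(
        q,
        k,
        v,
        &cu_seqlens,
        &cu_seqlens,
        /* max_seqlen_q */ 5,
        /* max_seqlen_k */ 5,
        softmax_scale,
        /* window_size_left */ Some(16),
        /* window_size_right */ Some(0),
    )
}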
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
pub fn flash_attn_varlen_alibi(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
/// * `window_size_left` - Limit attention to at most this many key/value tokens to the left of each query; `None` means no limit.
/// * `window_size_right` - Limit attention to at most this many key/value tokens to the right of each query; `None` means no limit.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
pub fn flash_attn_varlen_alibi_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
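For reference, this is how the `causal: bool` argument of the non-windowed functions maps onto the new window parameters; it mirrors the `window_size_left`/`window_size_right` setup in the wrappers above.

// Mapping from a causal flag to the new window arguments.
fn causal_to_window(causal: bool) -> (Option<usize>, Option<usize>) {
    if causal {
        (None, Some(0)) // mask out every token to the right of the query
    } else {
        (None, None) // unrestricted attention on both sides
    }
}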

View File

@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0"
[dependencies]
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
glob = "0.3.1"
rayon = "1.7.0"
bindgen_cuda = "0.1.1"

View File

@ -1,243 +1,8 @@
use std::io::Write;
fn main() {
println!("cargo:rerun-if-changed=build.rs");
let builder = bindgen_cuda::Builder::default();
println!("cargo:info={builder:?}");
let bindings = builder.build_ptx().unwrap();
bindings.write("src/lib.rs").unwrap();
cuda::set_include_dir();
let (write, kernel_paths) = cuda::build_ptx();
if write {
let mut file = std::fs::File::create("src/lib.rs").unwrap();
for kernel_path in kernel_paths {
let name = kernel_path.file_stem().unwrap().to_str().unwrap();
file.write_all(
format!(
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
name.to_uppercase().replace('.', "_"),
name
)
.as_bytes(),
)
.unwrap();
file.write_all(&[b'\n']).unwrap();
}
}
}
mod cuda {
use anyhow::{Context, Result};
pub fn set_include_dir() {
use std::path::PathBuf;
// NOTE: copied from cudarc build.rs.
// We can't actually set a env!() value from another crate,
// so we have to do that here.
// use PathBuf;
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
#[allow(unused)]
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
#[allow(unused)]
let roots = roots.into_iter().map(Into::<PathBuf>::into);
#[cfg(feature = "ci-check")]
let root: PathBuf = "ci".into();
#[cfg(not(feature = "ci-check"))]
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.unwrap();
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
}
pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) {
use rayon::prelude::*;
use std::path::PathBuf;
let out_dir = std::env::var("OUT_DIR").unwrap();
let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
.unwrap()
.map(|p| p.unwrap())
.collect();
let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
.unwrap()
.map(|p| p.unwrap())
.collect();
println!("cargo:rerun-if-changed=src/");
// for path in &kernel_paths {
// println!("cargo:rerun-if-changed={}", path.display());
// }
for path in &mut include_directories {
// println!("cargo:rerun-if-changed={}", path.display());
let destination =
std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
std::fs::copy(path.clone(), destination).unwrap();
// remove the filename from the path so it's just the directory
path.pop();
}
include_directories.sort();
include_directories.dedup();
let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
#[allow(unused)]
let include_options: Vec<String> = include_directories
.into_iter()
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
.collect::<Vec<_>>();
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let children = kernel_paths
.par_iter()
.flat_map(|p| {
let mut output = p.clone();
output.set_extension("ptx");
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
let ignore = if output_filename.exists() {
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
let in_modified = p.metadata().unwrap().modified().unwrap();
out_modified.duration_since(in_modified).is_ok()
} else {
false
};
if ignore {
None
} else {
let mut command = std::process::Command::new("nvcc");
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", &out_dir])
// Flash attention only
// .arg("--expt-relaxed-constexpr")
.args(&include_options);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(p);
Some((p, command.spawn()
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
}
})
.collect::<Vec<_>>();
let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
.unwrap()
.map(|p| p.unwrap())
.collect();
// We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed
// some old ones
let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len();
for (kernel_path, child) in children {
let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
assert!(
output.status.success(),
"nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
}
(write, kernel_paths)
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse code")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}
}
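The removed script wrote `src/lib.rs` itself, emitting one constant per compiled kernel in the form below; the `bindgen_cuda` `build_ptx()`/`write()` pair is expected to generate an equivalent file (the `affine` name is illustrative).

// Shape of the generated src/lib.rs, one constant per compiled .cu kernel.
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));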

View File

@ -11,7 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle = { workspace = true }
half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
@ -20,7 +20,7 @@ rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
metal = { workspace = true, optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }

View File

@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
[dependencies] [dependencies]
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } candle = { path = "../candle-core", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.3" } candle-nn = { path = "../candle-nn" }
prost = "0.12.1" prost = "0.12.1"
[build-dependencies] [build-dependencies]
@ -20,4 +20,3 @@ prost-build = "0.12.1"
[dev-dependencies] [dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] } anyhow = { version = "1", features = ["backtrace"] }
clap = { version = "4.2.4", features = ["derive"] } clap = { version = "4.2.4", features = ["derive"] }

View File

@ -15,9 +15,9 @@ crate-type = ["cdylib"]
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle-onnx = {path= "../candle-onnx", version = "0.3.3", optional = true}
candle = { workspace = true }
candle-nn = { workspace = true }
candle-onnx = { workspace = true, optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }

View File

@ -12,9 +12,9 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
candle-nn = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
num-traits = { workspace = true }
# App crates.

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle = { workspace = true }
candle-nn = { workspace = true }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View File

@ -7,7 +7,7 @@ keywords.workspace = true
categories.workspace = true
[dependencies]
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle = { workspace = true }
rand = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }