Fixed all bugs. Improved code quality. Added tests.

This commit is contained in:
Ivar Flakstad
2024-01-30 14:12:57 +01:00
parent 077e781f53
commit 8babfe0411
5 changed files with 1062 additions and 561 deletions

View File

@ -61,13 +61,21 @@ fn criterion_benchmark(c: &mut Criterion) {
run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
run_reduce(c, &device, (lo, up));
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
run_reduce(c, &device, (lo, up), false);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
run_arg_reduce(c, &device, (lo, up));
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
run_arg_reduce(c, &device, (lo, up), false);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
run_reduce(c, &device, (lo, up), true);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
run_arg_reduce(c, &device, (lo, up), true);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
}
}
@ -89,6 +97,7 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
DType::BF16 => "softmax_bf16",
_ => "softmax",
};
softmax(&a).unwrap();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
@ -105,19 +114,49 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
group.finish();
}
fn run_reduce<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
fn run_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};
let flops = b * m * k * T::DTYPE.size_in_bytes();
let name = match T::DTYPE {
DType::F32 => "reduce_f32",
DType::F16 => "reduce_f16",
DType::BF16 => "reduce_bf16",
DType::F32 => {
if strided {
"reduce_f32_strided"
} else {
"reduce_f32"
}
}
DType::F16 => {
if strided {
"reduce_f16_strided"
} else {
"reduce_f16"
}
}
DType::BF16 => {
if strided {
"reduce_bf16_strided"
} else {
"reduce_bf16"
}
}
_ => "reduce",
};
@ -140,20 +179,46 @@ fn run_arg_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;
let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};
let flops = b * m * k * T::DTYPE.size_in_bytes();
let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes());
let name = match T::DTYPE {
DType::F32 => "arg_reduce_f32",
DType::F16 => "arg_reduce_f16",
DType::BF16 => "arg_reduce_bf16",
_ => "reduce",
DType::F32 => {
if strided {
"arg_reduce_f32_strided"
} else {
"arg_reduce_f32"
}
}
DType::F16 => {
if strided {
"arg_reduce_f16_strided"
} else {
"arg_reduce_f16"
}
}
DType::BF16 => {
if strided {
"arg_reduce_bf16_strided"
} else {
"arg_reduce_bf16"
}
}
_ => "unknown",
};
let mut group = c.benchmark_group(device.bench_name(name));