All tests now panic instead of failing randomly.

This commit is contained in:
Nicolas Patry
2023-11-11 17:06:35 +01:00
parent 54355ff997
commit 3900091e75
7 changed files with 20 additions and 17 deletions

View File

@ -172,8 +172,8 @@ impl BackendStorage for MetalStorage {
.unwrap();
} else {
let name = match self.dtype {
DType::F32 => "affine_float",
DType::F16 => "affine_half",
DType::F32 => "affine_float_strided",
DType::F16 => "affine_half_strided",
dtype => todo!("Affine {dtype:?}"),
};
candle_metal_kernels::call_affine_strided(
@ -829,7 +829,6 @@ impl BackendStorage for MetalStorage {
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape();
let el_count = src_shape.elem_count();
// todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}");
if el_count == 0 {
return Ok(());
}
@ -851,7 +850,7 @@ impl BackendStorage for MetalStorage {
&src_l.stride(),
src_l.start_offset() * self.dtype.size_in_bytes(),
&mut dst.buffer,
dst_offset,
dst_offset * dst.dtype.size_in_bytes(),
)
.map_err(MetalError::from)?;
// command_buffer.commit();

View File

@ -16,16 +16,16 @@ kernel void NAME( \
if (gid >= dst_size) { \
return; \
} \
const size_t id_i = gid / right_size / left_size; \
const size_t id_i = (gid / right_size) % ids_size; \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t right_rank_i = gid % right_size; \
const size_t left_rank_i = gid % left_size; \
const size_t left_rank_i = gid / right_size / ids_size; \
/* \
// Force prevent out of bounds indexing \
// since there doesn't seem to be a good way to force crash \
// No need to check for zero: we're only allowing unsigned index types. \
*/ \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
output[gid] = input[src_i]; \
}
@ -75,7 +75,6 @@ kernel void FN_NAME( \
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310

View File

@ -112,7 +112,7 @@ macro_rules! ops{
($($name:ident),+) => {
pub mod contiguous {
pub struct Kernel(pub(crate) &'static str);
pub struct Kernel(pub &'static str);
$(
pub mod $name {
use super::Kernel;
@ -131,7 +131,7 @@ macro_rules! ops{
}
pub mod strided {
pub struct Kernel(pub(crate) &'static str);
pub struct Kernel(pub &'static str);
$(
pub mod $name {
use super::Kernel;
@ -172,7 +172,7 @@ pub enum MetalKernelError {
LockError(String),
#[error("Error while loading library: {0}")]
LoadLibraryError(String),
#[error("Error while loading function: {0}")]
#[error("Error while loading function: {0:?}")]
LoadFunctionError(String),
}
@ -1053,6 +1053,7 @@ mod tests {
&device,
command_buffer,
&kernels,
"affine_float",
size,
&input,
&mut output,
@ -1087,7 +1088,7 @@ mod tests {
&device,
command_buffer,
&kernels,
size,
"affine_float",
shape,
&input,
strides,
@ -1146,7 +1147,10 @@ mod tests {
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
);
}
#[test]
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
@ -1154,7 +1158,7 @@ mod tests {
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
);
}

View File

@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
&device,
command_buffer,
&kernels,
"affine_float",
v.len(),
&input,
&mut output,

View File

@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
kernel_name.0,
v.len(),
iterations,
total_time,
@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
for kernel_name in strided {
for kernel_name in &strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
kernel_name.to_string(),
kernel_name.0,
v.len(),
iterations,
total_time,