mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
All tests are panicking instead of random failure.
This commit is contained in:
@ -172,8 +172,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
} else {
|
} else {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "affine_float",
|
DType::F32 => "affine_float_strided",
|
||||||
DType::F16 => "affine_half",
|
DType::F16 => "affine_half_strided",
|
||||||
dtype => todo!("Affine {dtype:?}"),
|
dtype => todo!("Affine {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_affine_strided(
|
candle_metal_kernels::call_affine_strided(
|
||||||
@ -829,7 +829,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||||
let src_shape = src_l.shape();
|
let src_shape = src_l.shape();
|
||||||
let el_count = src_shape.elem_count();
|
let el_count = src_shape.elem_count();
|
||||||
// todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}");
|
|
||||||
if el_count == 0 {
|
if el_count == 0 {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
@ -851,7 +850,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&src_l.stride(),
|
&src_l.stride(),
|
||||||
src_l.start_offset() * self.dtype.size_in_bytes(),
|
src_l.start_offset() * self.dtype.size_in_bytes(),
|
||||||
&mut dst.buffer,
|
&mut dst.buffer,
|
||||||
dst_offset,
|
dst_offset * dst.dtype.size_in_bytes(),
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
// command_buffer.commit();
|
// command_buffer.commit();
|
||||||
|
@ -16,16 +16,16 @@ kernel void NAME( \
|
|||||||
if (gid >= dst_size) { \
|
if (gid >= dst_size) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
const size_t id_i = gid / right_size / left_size; \
|
const size_t id_i = (gid / right_size) % ids_size; \
|
||||||
|
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
|
||||||
const size_t right_rank_i = gid % right_size; \
|
const size_t right_rank_i = gid % right_size; \
|
||||||
const size_t left_rank_i = gid % left_size; \
|
const size_t left_rank_i = gid / right_size / ids_size; \
|
||||||
/* \
|
/* \
|
||||||
// Force prevent out of bounds indexing \
|
// Force prevent out of bounds indexing \
|
||||||
// since there doesn't seem to be a good way to force crash \
|
// since there doesn't seem to be a good way to force crash \
|
||||||
// No need to check for zero we're only allowing unsized. \
|
// No need to check for zero we're only allowing unsized. \
|
||||||
*/ \
|
*/ \
|
||||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
|
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
|
||||||
const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
|
|
||||||
output[gid] = input[src_i]; \
|
output[gid] = input[src_i]; \
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,7 +75,6 @@ kernel void FN_NAME( \
|
|||||||
|
|
||||||
|
|
||||||
INDEX_OP(is_u32_f32, uint, float)
|
INDEX_OP(is_u32_f32, uint, float)
|
||||||
INDEX_OP(is_u32_f16, uint, half)
|
|
||||||
|
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
|
@ -112,7 +112,7 @@ macro_rules! ops{
|
|||||||
($($name:ident),+) => {
|
($($name:ident),+) => {
|
||||||
|
|
||||||
pub mod contiguous {
|
pub mod contiguous {
|
||||||
pub struct Kernel(pub(crate) &'static str);
|
pub struct Kernel(pub &'static str);
|
||||||
$(
|
$(
|
||||||
pub mod $name {
|
pub mod $name {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
@ -131,7 +131,7 @@ macro_rules! ops{
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub mod strided {
|
pub mod strided {
|
||||||
pub struct Kernel(pub(crate) &'static str);
|
pub struct Kernel(pub &'static str);
|
||||||
$(
|
$(
|
||||||
pub mod $name {
|
pub mod $name {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
@ -172,7 +172,7 @@ pub enum MetalKernelError {
|
|||||||
LockError(String),
|
LockError(String),
|
||||||
#[error("Error while loading library: {0}")]
|
#[error("Error while loading library: {0}")]
|
||||||
LoadLibraryError(String),
|
LoadLibraryError(String),
|
||||||
#[error("Error while loading function: {0}")]
|
#[error("Error while loading function: {0:?}")]
|
||||||
LoadFunctionError(String),
|
LoadFunctionError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1053,6 +1053,7 @@ mod tests {
|
|||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
|
"affine_float",
|
||||||
size,
|
size,
|
||||||
&input,
|
&input,
|
||||||
&mut output,
|
&mut output,
|
||||||
@ -1087,7 +1088,7 @@ mod tests {
|
|||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
size,
|
"affine_float",
|
||||||
shape,
|
shape,
|
||||||
&input,
|
&input,
|
||||||
strides,
|
strides,
|
||||||
@ -1146,7 +1147,10 @@ mod tests {
|
|||||||
result,
|
result,
|
||||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn index_select_dim1() {
|
||||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||||
let shape = [5, 2];
|
let shape = [5, 2];
|
||||||
let ids = [0u32, 1, 0];
|
let ids = [0u32, 1, 0];
|
||||||
@ -1154,7 +1158,7 @@ mod tests {
|
|||||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
result,
|
result,
|
||||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
|
|||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
|
"affine_float",
|
||||||
v.len(),
|
v.len(),
|
||||||
&input,
|
&input,
|
||||||
&mut output,
|
&mut output,
|
@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
|
|||||||
println!(
|
println!(
|
||||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||||
type_name::<T>().split("::").last().unwrap(),
|
type_name::<T>().split("::").last().unwrap(),
|
||||||
kernel_name.to_string(),
|
kernel_name.0,
|
||||||
v.len(),
|
v.len(),
|
||||||
iterations,
|
iterations,
|
||||||
total_time,
|
total_time,
|
||||||
@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
|
|||||||
let shape = vec![2, 5_000];
|
let shape = vec![2, 5_000];
|
||||||
let strides = vec![2, 1];
|
let strides = vec![2, 1];
|
||||||
let offset = 0;
|
let offset = 0;
|
||||||
for kernel_name in strided {
|
for kernel_name in &strided {
|
||||||
let total_time = autoreleasepool(|| {
|
let total_time = autoreleasepool(|| {
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
|
|||||||
println!(
|
println!(
|
||||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||||
type_name::<T>().split("::").last().unwrap(),
|
type_name::<T>().split("::").last().unwrap(),
|
||||||
kernel_name.to_string(),
|
kernel_name.0,
|
||||||
v.len(),
|
v.len(),
|
||||||
iterations,
|
iterations,
|
||||||
total_time,
|
total_time,
|
Reference in New Issue
Block a user