mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
All tests are panicking instead of random failure.
This commit is contained in:
@ -172,8 +172,8 @@ impl BackendStorage for MetalStorage {
|
||||
.unwrap();
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_float",
|
||||
DType::F16 => "affine_half",
|
||||
DType::F32 => "affine_float_strided",
|
||||
DType::F16 => "affine_half_strided",
|
||||
dtype => todo!("Affine {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_affine_strided(
|
||||
@ -829,7 +829,6 @@ impl BackendStorage for MetalStorage {
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let src_shape = src_l.shape();
|
||||
let el_count = src_shape.elem_count();
|
||||
// todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}");
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
@ -851,7 +850,7 @@ impl BackendStorage for MetalStorage {
|
||||
&src_l.stride(),
|
||||
src_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&mut dst.buffer,
|
||||
dst_offset,
|
||||
dst_offset * dst.dtype.size_in_bytes(),
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
// command_buffer.commit();
|
||||
|
@ -16,16 +16,16 @@ kernel void NAME( \
|
||||
if (gid >= dst_size) { \
|
||||
return; \
|
||||
} \
|
||||
const size_t id_i = gid / right_size / left_size; \
|
||||
const size_t id_i = (gid / right_size) % ids_size; \
|
||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
|
||||
const size_t right_rank_i = gid % right_size; \
|
||||
const size_t left_rank_i = gid % left_size; \
|
||||
const size_t left_rank_i = gid / right_size / ids_size; \
|
||||
/* \
|
||||
// Force prevent out of bounds indexing \
|
||||
// since there doesn't seem to be a good way to force crash \
|
||||
// No need to check for zero we're only allowing unsized. \
|
||||
*/ \
|
||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
|
||||
const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
|
||||
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
|
||||
output[gid] = input[src_i]; \
|
||||
}
|
||||
|
||||
@ -75,7 +75,6 @@ kernel void FN_NAME( \
|
||||
|
||||
|
||||
INDEX_OP(is_u32_f32, uint, float)
|
||||
INDEX_OP(is_u32_f16, uint, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
|
@ -112,7 +112,7 @@ macro_rules! ops{
|
||||
($($name:ident),+) => {
|
||||
|
||||
pub mod contiguous {
|
||||
pub struct Kernel(pub(crate) &'static str);
|
||||
pub struct Kernel(pub &'static str);
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
@ -131,7 +131,7 @@ macro_rules! ops{
|
||||
}
|
||||
|
||||
pub mod strided {
|
||||
pub struct Kernel(pub(crate) &'static str);
|
||||
pub struct Kernel(pub &'static str);
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
@ -172,7 +172,7 @@ pub enum MetalKernelError {
|
||||
LockError(String),
|
||||
#[error("Error while loading library: {0}")]
|
||||
LoadLibraryError(String),
|
||||
#[error("Error while loading function: {0}")]
|
||||
#[error("Error while loading function: {0:?}")]
|
||||
LoadFunctionError(String),
|
||||
}
|
||||
|
||||
@ -1053,6 +1053,7 @@ mod tests {
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
"affine_float",
|
||||
size,
|
||||
&input,
|
||||
&mut output,
|
||||
@ -1087,7 +1088,7 @@ mod tests {
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
size,
|
||||
"affine_float",
|
||||
shape,
|
||||
&input,
|
||||
strides,
|
||||
@ -1146,7 +1147,10 @@ mod tests {
|
||||
result,
|
||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_select_dim1() {
|
||||
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
|
||||
let shape = [5, 2];
|
||||
let ids = [0u32, 1, 0];
|
||||
@ -1154,7 +1158,7 @@ mod tests {
|
||||
let result = run_index_select(&embedding, &shape, &ids, dim);
|
||||
assert_eq!(
|
||||
result,
|
||||
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
|
||||
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
"affine_float",
|
||||
v.len(),
|
||||
&input,
|
||||
&mut output,
|
@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
kernel_name.0,
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
||||
@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
|
||||
let shape = vec![2, 5_000];
|
||||
let strides = vec![2, 1];
|
||||
let offset = 0;
|
||||
for kernel_name in strided {
|
||||
for kernel_name in &strided {
|
||||
let total_time = autoreleasepool(|| {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start = Instant::now();
|
||||
@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
|
||||
println!(
|
||||
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
|
||||
type_name::<T>().split("::").last().unwrap(),
|
||||
kernel_name.to_string(),
|
||||
kernel_name.0,
|
||||
v.len(),
|
||||
iterations,
|
||||
total_time,
|
Reference in New Issue
Block a user