mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
1 Commits
0.7.1
...
ivarflakst
Author | SHA1 | Date | |
---|---|---|---|
933716b374 |
@ -822,10 +822,13 @@ impl BackendStorage for MetalStorage {
|
|||||||
layout.stride(),
|
layout.stride(),
|
||||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||||
),
|
),
|
||||||
|
!layout.is_contiguous(),
|
||||||
&t.buffer,
|
&t.buffer,
|
||||||
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||||
|
!t_l.is_contiguous(),
|
||||||
&f.buffer,
|
&f.buffer,
|
||||||
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||||
|
!f_l.is_contiguous(),
|
||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
|
@ -909,13 +909,22 @@ pub fn call_where_cond_strided(
|
|||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
cond: &Buffer,
|
cond: &Buffer,
|
||||||
(cond_stride, cond_offset): (&[usize], usize),
|
(cond_stride, cond_offset): (&[usize], usize),
|
||||||
|
cond_is_strided: bool,
|
||||||
left: &Buffer,
|
left: &Buffer,
|
||||||
(left_stride, left_offset): (&[usize], usize),
|
(left_stride, left_offset): (&[usize], usize),
|
||||||
|
left_is_strided: bool,
|
||||||
right: &Buffer,
|
right: &Buffer,
|
||||||
(right_stride, right_offset): (&[usize], usize),
|
(right_stride, right_offset): (&[usize], usize),
|
||||||
|
right_is_strided: bool,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
let constants = Some(ConstantValues::new(vec![
|
||||||
|
(0, Value::Bool(cond_is_strided)),
|
||||||
|
(1, Value::Bool(left_is_strided)),
|
||||||
|
(2, Value::Bool(right_is_strided)),
|
||||||
|
]));
|
||||||
|
let pipeline =
|
||||||
|
kernels.load_pipeline_with_constants(device, Source::Ternary, name, constants)?;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
@ -1,14 +1,20 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
#
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
constant bool IDS_STRIDED [[function_constant(0)]];
|
||||||
|
constant bool T_STRIDED [[function_constant(1)]];
|
||||||
|
constant bool F_STRIDED [[function_constant(2)]];
|
||||||
|
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
constant size_t &num_dims,
|
constant const size_t &num_dims,
|
||||||
constant size_t *dims,
|
constant const size_t *dims,
|
||||||
constant size_t *strides
|
constant const size_t *strides
|
||||||
) {
|
) {
|
||||||
uint strided_i = 0;
|
uint strided_i = 0;
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
for (uint d = 0; d < num_dims; d++) {
|
for (uint d = 0; d < num_dims; d++) {
|
||||||
uint dim_idx = num_dims - 1 - d;
|
uint dim_idx = num_dims - 1 - d;
|
||||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||||
@ -17,6 +23,7 @@ METAL_FUNC uint get_strided_index(
|
|||||||
return strided_i;
|
return strided_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename T, typename ID>
|
template<typename T, typename ID>
|
||||||
METAL_FUNC void where_cond(
|
METAL_FUNC void where_cond(
|
||||||
constant size_t &numel,
|
constant size_t &numel,
|
||||||
@ -34,10 +41,20 @@ METAL_FUNC void where_cond(
|
|||||||
if (i >= numel){
|
if (i >= numel){
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
uint strided_i = get_strided_index(i, num_dims, dims, strides);
|
uint strided_i = i;
|
||||||
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
|
uint strided_i_t = i;
|
||||||
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
|
uint strided_i_f = i;
|
||||||
out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f];
|
if (IDS_STRIDED) {
|
||||||
|
strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||||
|
}
|
||||||
|
if (T_STRIDED) {
|
||||||
|
strided_i_t = get_strided_index(i, num_dims, dims, strides_t);
|
||||||
|
}
|
||||||
|
if (F_STRIDED) {
|
||||||
|
strided_i_f = get_strided_index(i, num_dims, dims, strides_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
out[i] = select(f[strided_i_t], t[strided_i_f], ids[strided_i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define WHERE_OP(T, ID, FN_NAME) \
|
#define WHERE_OP(T, ID, FN_NAME) \
|
||||||
|
@ -803,10 +803,13 @@ fn run_where_cond<I: Clone, T: Clone>(
|
|||||||
shape,
|
shape,
|
||||||
&cond,
|
&cond,
|
||||||
(&cond_stride, cond_offset),
|
(&cond_stride, cond_offset),
|
||||||
|
true,
|
||||||
&left,
|
&left,
|
||||||
(&left_stride, left_offset),
|
(&left_stride, left_offset),
|
||||||
|
true,
|
||||||
&right,
|
&right,
|
||||||
(&cond_stride, cond_offset),
|
(&cond_stride, cond_offset),
|
||||||
|
true,
|
||||||
&output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
Reference in New Issue
Block a user