mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Process unary functions per block (#180)
* Process unary functions per block. * Add some inline hints.
This commit is contained in:
@ -98,6 +98,7 @@ struct Sum<'a> {
|
||||
}
|
||||
|
||||
impl<'a> Map1 for Sum<'a> {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, src: &[T], src_layout: &Layout) -> Result<Vec<T>> {
|
||||
let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
|
||||
for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
|
||||
@ -115,10 +116,35 @@ impl<'a> Map1 for Sum<'a> {
|
||||
}
|
||||
|
||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
|
||||
match layout.contiguous_offsets() {
|
||||
Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(),
|
||||
None => layout.strided_index().map(|i| f(vs[i])).collect(),
|
||||
let mut result = vec![];
|
||||
result.reserve(layout.shape().elem_count());
|
||||
match layout.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
for &v in vs[start_offset..start_offset + len].iter() {
|
||||
result.push(f(v))
|
||||
}
|
||||
}
|
||||
crate::StridedBlocks::MultipleBlocks {
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
// Specialize the case where block_len is one to avoid the second loop.
|
||||
if block_len == 1 {
|
||||
for index in block_start_index {
|
||||
let v = unsafe { vs.get_unchecked(index) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
} else {
|
||||
for index in block_start_index {
|
||||
for offset in 0..block_len {
|
||||
let v = unsafe { vs.get_unchecked(index + offset) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
// This function maps over two strided index sequences.
|
||||
|
@ -95,21 +95,27 @@ macro_rules! bin_op {
|
||||
const NAME: &'static str = $name;
|
||||
const KERNEL: &'static str = concat!("b", $name);
|
||||
const V: Self = $op;
|
||||
#[inline(always)]
|
||||
fn bf16(v1: bf16, v2: bf16) -> bf16 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v1: f16, v2: f16) -> f16 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v1: u8, v2: u8) -> u8 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v1: u32, v2: u32) -> u32 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
@ -128,21 +134,27 @@ macro_rules! unary_op {
|
||||
const NAME: &'static str = $name;
|
||||
const KERNEL: &'static str = concat!("u", $name);
|
||||
const V: Self = $op;
|
||||
#[inline(always)]
|
||||
fn bf16($a: bf16) -> bf16 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16($a: f16) -> f16 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32($a: f32) -> f32 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64($a: f64) -> f64 {
|
||||
$e
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
todo!("no unary function for u8")
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
todo!("no unary function for u32")
|
||||
}
|
||||
@ -164,6 +176,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt());
|
||||
impl UnaryOp for Gelu {
|
||||
const NAME: &'static str = "gelu";
|
||||
const V: Self = Gelu;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
bf16::from_f32_const(0.5)
|
||||
* v
|
||||
@ -174,6 +187,7 @@ impl UnaryOp for Gelu {
|
||||
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
|
||||
))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
f16::from_f32_const(0.5)
|
||||
* v
|
||||
@ -184,19 +198,23 @@ impl UnaryOp for Gelu {
|
||||
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
|
||||
))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
0.5 * v
|
||||
* (1.0
|
||||
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
0.5 * v
|
||||
* (1.0
|
||||
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(_: u8) -> u8 {
|
||||
0
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
@ -207,21 +225,27 @@ impl UnaryOp for Relu {
|
||||
const NAME: &'static str = "relu";
|
||||
const KERNEL: &'static str = "urelu";
|
||||
const V: Self = Relu;
|
||||
#[inline(always)]
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
v.max(bf16::ZERO)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f16(v: f16) -> f16 {
|
||||
v.max(f16::ZERO)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f32(v: f32) -> f32 {
|
||||
v.max(0f32)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn f64(v: f64) -> f64 {
|
||||
v.max(0f64)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u8(v: u8) -> u8 {
|
||||
v
|
||||
}
|
||||
#[inline(always)]
|
||||
fn u32(v: u32) -> u32 {
|
||||
v
|
||||
}
|
||||
|
Reference in New Issue
Block a user