Add an erf based gelu op (#900)

* Erf based gelu.

* Add the erf backed gelu.

* Test the new gelu op (which is not gelu_new).
This commit is contained in:
Laurent Mazare
2023-09-19 19:54:28 +01:00
committed by GitHub
parent 34f2ecbc3b
commit d7e48234d4
8 changed files with 851 additions and 1 deletions

View File

@ -442,6 +442,9 @@ impl Tensor {
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(_, UnaryOp::GeluErf) => {
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
}
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;

763
candle-core/src/cpu/erf.rs Normal file
View File

@ -0,0 +1,763 @@
#![allow(clippy::excessive_precision)]
// Code taken from https://github.com/statrs-dev/statrs
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
//! related functions
mod evaluate {
//! Provides functions that don't have a numerical solution and must
//! be solved computationally (e.g. evaluation of a polynomial)
/// evaluates a polynomial at `z` where `coeff` are the coeffecients
/// to a polynomial of order `k` where `k` is the length of `coeff` and the
/// coeffecient
/// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
/// `2z^2 - z + 3`
///
/// # Remarks
///
/// Returns 0 for a 0 length coefficient slice
pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
let n = coeff.len();
if n == 0 {
return 0.0;
}
let mut sum = *coeff.last().unwrap();
for c in coeff[0..n - 1].iter().rev() {
sum = *c + z * sum;
}
sum
}
}
use std::f64;
/// `erf` calculates the error function at `x`.
pub fn erf(x: f64) -> f64 {
if x.is_nan() {
f64::NAN
} else if x >= 0.0 && x.is_infinite() {
1.0
} else if x <= 0.0 && x.is_infinite() {
-1.0
} else if x == 0. {
0.0
} else {
erf_impl(x, false)
}
}
/// `erf_inv` calculates the inverse error function
/// at `x`.
pub fn erf_inv(x: f64) -> f64 {
if x == 0.0 {
0.0
} else if x >= 1.0 {
f64::INFINITY
} else if x <= -1.0 {
f64::NEG_INFINITY
} else if x < 0.0 {
erf_inv_impl(-x, 1.0 + x, -1.0)
} else {
erf_inv_impl(x, 1.0 - x, 1.0)
}
}
/// `erfc` calculates the complementary error function
/// at `x`.
pub fn erfc(x: f64) -> f64 {
if x.is_nan() {
f64::NAN
} else if x == f64::INFINITY {
0.0
} else if x == f64::NEG_INFINITY {
2.0
} else {
erf_impl(x, true)
}
}
/// `erfc_inv` calculates the complementary inverse
/// error function at `x`.
pub fn erfc_inv(x: f64) -> f64 {
if x <= 0.0 {
f64::INFINITY
} else if x >= 2.0 {
f64::NEG_INFINITY
} else if x > 1.0 {
erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
} else {
erf_inv_impl(1.0 - x, x, 1.0)
}
}
// **********************************************************
// ********** Coefficients for erf_impl polynomial **********
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_impl`
/// in the interval [1e-10, 0.5].
const ERF_IMPL_AN: &[f64] = &[
0.00337916709551257388990745,
-0.00073695653048167948530905,
-0.374732337392919607868241,
0.0817442448733587196071743,
-0.0421089319936548595203468,
0.0070165709512095756344528,
-0.00495091255982435110337458,
0.000871646599037922480317225,
];
/// Polynomial coefficients for a denominator of `erf_impl`
/// in the interval [1e-10, 0.5]
const ERF_IMPL_AD: &[f64] = &[
1.0,
-0.218088218087924645390535,
0.412542972725442099083918,
-0.0841891147873106755410271,
0.0655338856400241519690695,
-0.0120019604454941768171266,
0.00408165558926174048329689,
-0.000615900721557769691924509,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BN: &[f64] = &[
-0.0361790390718262471360258,
0.292251883444882683221149,
0.281447041797604512774415,
0.125610208862766947294894,
0.0274135028268930549240776,
0.00250839672168065762786937,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BD: &[f64] = &[
1.0,
1.8545005897903486499845,
1.43575803037831418074962,
0.582827658753036572454135,
0.124810476932949746447682,
0.0113724176546353285778481,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CN: &[f64] = &[
-0.0397876892611136856954425,
0.153165212467878293257683,
0.191260295600936245503129,
0.10276327061989304213645,
0.029637090615738836726027,
0.0046093486780275489468812,
0.000307607820348680180548455,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CD: &[f64] = &[
1.0,
1.95520072987627704987886,
1.64762317199384860109595,
0.768238607022126250082483,
0.209793185936509782784315,
0.0319569316899913392596356,
0.00213363160895785378615014,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DN: &[f64] = &[
-0.0300838560557949717328341,
0.0538578829844454508530552,
0.0726211541651914182692959,
0.0367628469888049348429018,
0.00964629015572527529605267,
0.00133453480075291076745275,
0.778087599782504251917881e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DD: &[f64] = &[
1.0,
1.75967098147167528287343,
1.32883571437961120556307,
0.552528596508757581287907,
0.133793056941332861912279,
0.0179509645176280768640766,
0.00104712440019937356634038,
-0.106640381820357337177643e-7,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_EN: &[f64] = &[
-0.0117907570137227847827732,
0.014262132090538809896674,
0.0202234435902960820020765,
0.00930668299990432009042239,
0.00213357802422065994322516,
0.00025022987386460102395382,
0.120534912219588189822126e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_ED: &[f64] = &[
1.0,
1.50376225203620482047419,
0.965397786204462896346934,
0.339265230476796681555511,
0.0689740649541569716897427,
0.00771060262491768307365526,
0.000371421101531069302990367,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FN: &[f64] = &[
-0.00546954795538729307482955,
0.00404190278731707110245394,
0.0054963369553161170521356,
0.00212616472603945399437862,
0.000394984014495083900689956,
0.365565477064442377259271e-4,
0.135485897109932323253786e-5,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FD: &[f64] = &[
1.0,
1.21019697773630784832251,
0.620914668221143886601045,
0.173038430661142762569515,
0.0276550813773432047594539,
0.00240625974424309709745382,
0.891811817251336577241006e-4,
-0.465528836283382684461025e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GN: &[f64] = &[
-0.00270722535905778347999196,
0.0013187563425029400461378,
0.00119925933261002333923989,
0.00027849619811344664248235,
0.267822988218331849989363e-4,
0.923043672315028197865066e-6,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GD: &[f64] = &[
1.0,
0.814632808543141591118279,
0.268901665856299542168425,
0.0449877216103041118694989,
0.00381759663320248459168994,
0.000131571897888596914350697,
0.404815359675764138445257e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HN: &[f64] = &[
-0.00109946720691742196814323,
0.000406425442750422675169153,
0.000274499489416900707787024,
0.465293770646659383436343e-4,
0.320955425395767463401993e-5,
0.778286018145020892261936e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HD: &[f64] = &[
1.0,
0.588173710611846046373373,
0.139363331289409746077541,
0.0166329340417083678763028,
0.00100023921310234908642639,
0.24254837521587225125068e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_IN: &[f64] = &[
-0.00056907993601094962855594,
0.000169498540373762264416984,
0.518472354581100890120501e-4,
0.382819312231928859704678e-5,
0.824989931281894431781794e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_ID: &[f64] = &[
1.0,
0.339637250051139347430323,
0.043472647870310663055044,
0.00248549335224637114641629,
0.535633305337152900549536e-4,
-0.117490944405459578783846e-12,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JN: &[f64] = &[
-0.000241313599483991337479091,
0.574224975202501512365975e-4,
0.115998962927383778460557e-4,
0.581762134402593739370875e-6,
0.853971555085673614607418e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JD: &[f64] = &[
1.0,
0.233044138299687841018015,
0.0204186940546440312625597,
0.000797185647564398289151125,
0.117019281670172327758019e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KN: &[f64] = &[
-0.000146674699277760365803642,
0.162666552112280519955647e-4,
0.269116248509165239294897e-5,
0.979584479468091935086972e-7,
0.101994647625723465722285e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KD: &[f64] = &[
1.0,
0.165907812944847226546036,
0.0103361716191505884359634,
0.000286593026373868366935721,
0.298401570840900340874568e-5,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LN: &[f64] = &[
-0.583905797629771786720406e-4,
0.412510325105496173512992e-5,
0.431790922420250949096906e-6,
0.993365155590013193345569e-8,
0.653480510020104699270084e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LD: &[f64] = &[
1.0,
0.105077086072039915406159,
0.00414278428675475620830226,
0.726338754644523769144108e-4,
0.477818471047398785369849e-6,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MN: &[f64] = &[
-0.196457797609229579459841e-4,
0.157243887666800692441195e-5,
0.543902511192700878690335e-7,
0.317472492369117710852685e-9,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MD: &[f64] = &[
1.0,
0.052803989240957632204885,
0.000926876069151753290378112,
0.541011723226630257077328e-5,
0.535093845803642394908747e-15,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_NN: &[f64] = &[
-0.789224703978722689089794e-5,
0.622088451660986955124162e-6,
0.145728445676882396797184e-7,
0.603715505542715364529243e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_ND: &[f64] = &[
1.0,
0.0375328846356293715248719,
0.000467919535974625308126054,
0.193847039275845656900547e-5,
];
// **********************************************************
// ********** Coefficients for erf_inv_impl polynomial ******
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AN: &[f64] = &[
-0.000508781949658280665617,
-0.00836874819741736770379,
0.0334806625409744615033,
-0.0126926147662974029034,
-0.0365637971411762664006,
0.0219878681111168899165,
0.00822687874676915743155,
-0.00538772965071242932965,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AD: &[f64] = &[
1.0,
-0.970005043303290640362,
-1.56574558234175846809,
1.56221558398423026363,
0.662328840472002992063,
-0.71228902341542847553,
-0.0527396382340099713954,
0.0795283687341571680018,
-0.00233393759374190016776,
0.000886216390456424707504,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BN: &[f64] = &[
-0.202433508355938759655,
0.105264680699391713268,
8.37050328343119927838,
17.6447298408374015486,
-18.8510648058714251895,
-44.6382324441786960818,
17.445385985570866523,
21.1294655448340526258,
-3.67192254707729348546,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BD: &[f64] = &[
1.0,
6.24264124854247537712,
3.9713437953343869095,
-28.6608180499800029974,
-20.1432634680485188801,
48.5609213108739935468,
10.8268667355460159008,
-22.6436933413139721736,
1.72114765761200282724,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CN: &[f64] = &[
-0.131102781679951906451,
-0.163794047193317060787,
0.117030156341995252019,
0.387079738972604337464,
0.337785538912035898924,
0.142869534408157156766,
0.0290157910005329060432,
0.00214558995388805277169,
-0.679465575181126350155e-6,
0.285225331782217055858e-7,
-0.681149956853776992068e-9,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CD: &[f64] = &[
1.0,
3.46625407242567245975,
5.38168345707006855425,
4.77846592945843778382,
2.59301921623620271374,
0.848854343457902036425,
0.152264338295331783612,
0.01105924229346489121,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DN: &[f64] = &[
-0.0350353787183177984712,
-0.00222426529213447927281,
0.0185573306514231072324,
0.00950804701325919603619,
0.00187123492819559223345,
0.000157544617424960554631,
0.460469890584317994083e-5,
-0.230404776911882601748e-9,
0.266339227425782031962e-11,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DD: &[f64] = &[
1.0,
1.3653349817554063097,
0.762059164553623404043,
0.220091105764131249824,
0.0341589143670947727934,
0.00263861676657015992959,
0.764675292302794483503e-4,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_EN: &[f64] = &[
-0.0167431005076633737133,
-0.00112951438745580278863,
0.00105628862152492910091,
0.000209386317487588078668,
0.149624783758342370182e-4,
0.449696789927706453732e-6,
0.462596163522878599135e-8,
-0.281128735628831791805e-13,
0.99055709973310326855e-16,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_ED: &[f64] = &[
1.0,
0.591429344886417493481,
0.138151865749083321638,
0.0160746087093676504695,
0.000964011807005165528527,
0.275335474764726041141e-4,
0.282243172016108031869e-6,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FN: &[f64] = &[
-0.0024978212791898131227,
-0.779190719229053954292e-5,
0.254723037413027451751e-4,
0.162397777342510920873e-5,
0.396341011304801168516e-7,
0.411632831190944208473e-9,
0.145596286718675035587e-11,
-0.116765012397184275695e-17,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FD: &[f64] = &[
1.0,
0.207123112214422517181,
0.0169410838120975906478,
0.000690538265622684595676,
0.145007359818232637924e-4,
0.144437756628144157666e-6,
0.509761276599778486139e-9,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GN: &[f64] = &[
-0.000539042911019078575891,
-0.28398759004727721098e-6,
0.899465114892291446442e-6,
0.229345859265920864296e-7,
0.225561444863500149219e-9,
0.947846627503022684216e-12,
0.135880130108924861008e-14,
-0.348890393399948882918e-21,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GD: &[f64] = &[
1.0,
0.0845746234001899436914,
0.00282092984726264681981,
0.468292921940894236786e-4,
0.399968812193862100054e-6,
0.161809290887904476097e-8,
0.231558608310259605225e-11,
];
/// `erf_impl` computes the error function at `z`.
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
fn erf_impl(z: f64, inv: bool) -> f64 {
if z < 0.0 {
if !inv {
return -erf_impl(-z, false);
}
if z < -0.5 {
return 2.0 - erf_impl(-z, true);
}
return 1.0 + erf_impl(-z, false);
}
let result = if z < 0.5 {
if z < 1e-10 {
z * 1.125 + z * 0.003379167095512573896158903121545171688
} else {
z * 1.125
+ z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
}
} else if z < 110.0 {
let (r, b) = if z < 0.75 {
(
evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
/ evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
0.3440242112,
)
} else if z < 1.25 {
(
evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
/ evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
0.419990927,
)
} else if z < 2.25 {
(
evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
/ evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
0.4898625016,
)
} else if z < 3.5 {
(
evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
/ evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
0.5317370892,
)
} else if z < 5.25 {
(
evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
/ evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
0.5489973426,
)
} else if z < 8.0 {
(
evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
/ evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
0.5571740866,
)
} else if z < 11.5 {
(
evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
/ evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
0.5609807968,
)
} else if z < 17.0 {
(
evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
/ evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
0.5626493692,
)
} else if z < 24.0 {
(
evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
/ evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
0.5634598136,
)
} else if z < 38.0 {
(
evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
/ evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
0.5638477802,
)
} else if z < 60.0 {
(
evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
/ evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
0.5640528202,
)
} else if z < 85.0 {
(
evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
/ evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
0.5641309023,
)
} else {
(
evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
/ evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
0.5641584396,
)
};
let g = (-z * z).exp() / z;
g * b + g * r
} else {
0.0
};
if inv && z >= 0.5 {
result
} else if z >= 0.5 || inv {
1.0 - result
} else {
result
}
}
// `erf_inv_impl` computes the inverse error function where
// `p`,`q`, and `s` are the first, second, and third intermediate
// parameters respectively
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
let result = if p <= 0.5 {
let y = 0.0891314744949340820313;
let g = p * (p + 10.0);
let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
g * y + g * r
} else if q >= 0.25 {
let y = 2.249481201171875;
let g = (-2.0 * q.ln()).sqrt();
let xs = q - 0.25;
let r =
evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
g / (y + r)
} else {
let x = (-q.ln()).sqrt();
if x < 3.0 {
let y = 0.807220458984375;
let xs = x - 1.125;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_CD);
y * x + r * x
} else if x < 6.0 {
let y = 0.93995571136474609375;
let xs = x - 3.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_DD);
y * x + r * x
} else if x < 18.0 {
let y = 0.98362827301025390625;
let xs = x - 6.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_ED);
y * x + r * x
} else if x < 44.0 {
let y = 0.99714565277099609375;
let xs = x - 18.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_FD);
y * x + r * x
} else {
let y = 0.99941349029541015625;
let xs = x - 44.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_GD);
y * x + r * x
}
};
s * result
}

View File

@ -1,3 +1,4 @@
pub mod erf;
pub mod kernels;
trait Cpu<const ARR: usize> {

View File

@ -58,6 +58,7 @@ pub enum UnaryOp {
Sqr,
Sqrt,
Gelu,
GeluErf,
Relu,
Tanh,
}
@ -325,6 +326,7 @@ pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
@ -621,6 +623,40 @@ impl UnaryOpT for Gelu {
}
}
impl UnaryOpT for GeluErf {
const NAME: &'static str = "gelu_erf";
const KERNEL: &'static str = "ugelu_erf";
const V: Self = GeluErf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
(crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
impl UnaryOpT for Relu {
const NAME: &'static str = "relu";
const KERNEL: &'static str = "urelu";

View File

@ -489,6 +489,7 @@ impl Tensor {
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
unary_op!(gelu_erf, GeluErf);
unary_op!(relu, Relu);
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple

View File

@ -1,4 +1,4 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@ -44,6 +44,26 @@ fn clamp(device: &Device) -> Result<()> {
Ok(())
}
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
[
[-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
[-0.004, 0.8413, 3.9999, -0.046, 0.3457],
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
]
);
Ok(())
}
fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?;
@ -908,6 +928,7 @@ test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);

View File

@ -129,6 +129,10 @@ __device__ __forceinline__ float powg(float a, float b) { return powf(a, b); }
__device__ __forceinline__ double powg(double a, double b) { return pow(a, b); }
__device__ __forceinline__ float tanhg(float a) { return tanhf(a); }
__device__ __forceinline__ double tanhg(double a) { return tanh(a); }
__device__ __forceinline__ float erfg(float a) { return erff(a); }
__device__ __forceinline__ double erfg(double a) { return erf(a); }
__device__ __forceinline__ float normcdfg(float a) { return normcdff(a); }
__device__ __forceinline__ double normcdfg(double a) { return normcdf(a); }
__device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); }
__device__ __forceinline__ double maxg(double a, double b) { return fmax(a, b); }
__device__ __forceinline__ float ming(float a, float b) { return fminf(a, b); }
@ -157,6 +161,8 @@ __device__ __forceinline__ __half sing(__half a) { return hsin(a); }
__device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; }
__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); }
__device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); }
__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __half logg(__half a) { return hlog(a); }
__device__ __forceinline__ __half expg(__half a) { return hexp(a); }
@ -173,6 +179,8 @@ __device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a);
__device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; }
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) { return __float2bfloat16(erff(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) { return __float2bfloat16(normcdff(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }
__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }

View File

@ -28,6 +28,11 @@ extern "C" __global__ void FN_NAME( \
} \
} \
template<typename T>
__device__ __forceinline__ T gelu_erf_fwd(T x) {
return x * normcdfg(x);
}
template<typename T>
__device__ __forceinline__ T gelu_fwd(T x) {
T x_sq = x * x;
@ -86,10 +91,13 @@ UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x))
UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
@ -104,10 +112,13 @@ UNARY_OP(__half, ulog_f16, logg(x))
UNARY_OP(__half, usin_f16, sing(x))
UNARY_OP(__half, ucos_f16, cosg(x))
UNARY_OP(__half, utanh_f16, tanhg(x))
UNARY_OP(__half, uerf_f16, erfg(x))
UNARY_OP(__half, unormcdf_f16, normcdfg(x))
UNARY_OP(__half, uabs_f16, absg(x))
UNARY_OP(__half, usqr_f16, x*x)
UNARY_OP(__half, usqrt_f16, sqrtg(x))
UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP1(__half, upowf_f16, powg(x, param))
@ -131,6 +142,10 @@ UNARY_OP(float, ucos_f32, cosg(x))
UNARY_OP(double, ucos_f64, cosg(x))
UNARY_OP(float, utanh_f32, tanhg(x))
UNARY_OP(double, utanh_f64, tanhg(x))
UNARY_OP(float, uerf_f32, erfg(x))
UNARY_OP(double, uerf_f64, erfg(x))
UNARY_OP(float, unormcdf_f32, normcdfg(x))
UNARY_OP(double, unormcdf_f64, normcdfg(x))
UNARY_OP(float, uabs_f32, absg(x))
UNARY_OP(double, uabs_f64, absg(x))
UNARY_OP(float, usqr_f32, x*x)
@ -139,6 +154,8 @@ UNARY_OP(float, usqrt_f32, sqrtg(x))
UNARY_OP(double, usqrt_f64, sqrtg(x))
UNARY_OP(float, ugelu_f32, gelu_fwd(x))
UNARY_OP(double, ugelu_f64, gelu_fwd(x))
UNARY_OP(float, ugelu_erf_f32, gelu_erf_fwd(x))
UNARY_OP(double, ugelu_erf_f64, gelu_erf_fwd(x))
UNARY_OP(float, urelu_f32, relu_fwd(x))
UNARY_OP(double, urelu_f64, relu_fwd(x))
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))