mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Fixes + cache compute_pipeline_state.
This commit is contained in:
@ -86,7 +86,6 @@ impl MetalDevice {
|
|||||||
|
|
||||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||||
// debug!("Allocate 1 - buffer size {size}");
|
|
||||||
self.device
|
self.device
|
||||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||||
}
|
}
|
||||||
@ -115,7 +114,9 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
self.device.wait_until_completed();
|
self.device.wait_until_completed();
|
||||||
|
println!("Wait took {:?}", start.elapsed());
|
||||||
|
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
DType::U8 => Ok(CpuStorage::U8(
|
DType::U8 => Ok(CpuStorage::U8(
|
||||||
@ -414,6 +415,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
|
self.device.wait_until_completed();
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -899,9 +901,12 @@ impl BackendDevice for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||||
// TODO Is there a faster way ?
|
let buffer = self.new_buffer(shape.elem_count(), dtype);
|
||||||
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
|
Ok(MetalStorage {
|
||||||
self.storage_from_cpu_storage(&cpu_storage)
|
buffer,
|
||||||
|
device: self.clone(),
|
||||||
|
dtype,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#![allow(clippy::too_many_arguments)]
|
#![allow(clippy::too_many_arguments)]
|
||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
|
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||||
ComputePipelineState, Device, Function, Library, MTLSize,
|
Device, Function, Library, MTLSize,
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
@ -174,6 +174,10 @@ pub enum MetalKernelError {
|
|||||||
LoadLibraryError(String),
|
LoadLibraryError(String),
|
||||||
#[error("Error while loading function: {0:?}")]
|
#[error("Error while loading function: {0:?}")]
|
||||||
LoadFunctionError(String),
|
LoadFunctionError(String),
|
||||||
|
#[error("Failed to create compute function")]
|
||||||
|
FailedToCreateComputeFunction,
|
||||||
|
#[error("Failed to create pipeline")]
|
||||||
|
FailedToCreatePipeline(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
||||||
@ -184,19 +188,22 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
|||||||
|
|
||||||
type KernelMap<T> = HashMap<&'static str, T>;
|
type KernelMap<T> = HashMap<&'static str, T>;
|
||||||
type Libraries = HashMap<Source, Library>;
|
type Libraries = HashMap<Source, Library>;
|
||||||
type Functions = KernelMap<Function>;
|
type Pipelines = KernelMap<ComputePipelineState>;
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug)]
|
||||||
pub struct Kernels {
|
pub struct Kernels {
|
||||||
libraries: RwLock<Libraries>,
|
libraries: RwLock<Libraries>,
|
||||||
funcs: RwLock<Functions>,
|
pipelines: RwLock<Pipelines>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Kernels {
|
impl Kernels {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
let libraries = RwLock::new(Libraries::new());
|
let libraries = RwLock::new(Libraries::new());
|
||||||
let funcs = RwLock::new(Functions::new());
|
let pipelines = RwLock::new(Pipelines::new());
|
||||||
Self { libraries, funcs }
|
Self {
|
||||||
|
libraries,
|
||||||
|
pipelines,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
||||||
@ -243,22 +250,43 @@ impl Kernels {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_function(
|
fn load_function(
|
||||||
&self,
|
&self,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
source: Source,
|
source: Source,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
) -> Result<Function, MetalKernelError> {
|
) -> Result<Function, MetalKernelError> {
|
||||||
let mut funcs = self.funcs.write()?;
|
let func = self
|
||||||
if let Some(func) = funcs.get(name) {
|
.load_library(device, source)?
|
||||||
Ok(func.clone())
|
.get_function(name, None)
|
||||||
|
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
||||||
|
Ok(func)
|
||||||
|
// let mut funcs = self.funcs.write()?;
|
||||||
|
// if let Some(func) = funcs.get(name) {
|
||||||
|
// Ok(func.clone())
|
||||||
|
// } else {
|
||||||
|
// funcs.insert(name, func.clone());
|
||||||
|
// Ok(func)
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_pipeline(
|
||||||
|
&self,
|
||||||
|
device: &Device,
|
||||||
|
source: Source,
|
||||||
|
name: &'static str,
|
||||||
|
) -> Result<ComputePipelineState, MetalKernelError> {
|
||||||
|
let mut pipelines = self.pipelines.write()?;
|
||||||
|
if let Some(pipeline) = pipelines.get(name) {
|
||||||
|
Ok(pipeline.clone())
|
||||||
} else {
|
} else {
|
||||||
let func = self
|
let func = self.load_function(device, source, name)?;
|
||||||
.load_library(device, source)?
|
let pipeline = device
|
||||||
.get_function(name, None)
|
.new_compute_pipeline_state_with_function(&func)
|
||||||
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
.map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
|
||||||
funcs.insert(name, func.clone());
|
pipelines.insert(name, pipeline.clone());
|
||||||
Ok(func)
|
|
||||||
|
Ok(pipeline)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -274,16 +302,8 @@ pub fn call_unary_contiguous(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
// println!("Kernel {:?}", kernel_name.0);
|
// println!("Kernel {:?}", kernel_name.0);
|
||||||
// assert_eq!(input.length(), output.length());
|
// assert_eq!(input.length(), output.length());
|
||||||
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
|
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
@ -306,15 +326,7 @@ pub fn call_unary_strided(
|
|||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
output_offset: usize,
|
output_offset: usize,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let func = kernels.load_function(device, Source::Unary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let num_dims: usize = shape.len();
|
let num_dims: usize = shape.len();
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
@ -351,17 +363,7 @@ pub fn call_binary_contiguous(
|
|||||||
right: &Buffer,
|
right: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
// println!("Kernel {:?}", kernel_name.0);
|
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||||
// assert_eq!(input.length(), output.length());
|
|
||||||
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
|
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -389,15 +391,7 @@ pub fn call_binary_strided(
|
|||||||
right_offset: usize,
|
right_offset: usize,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let func = kernels.load_function(device, Source::Binary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let num_dims: usize = shape.len();
|
let num_dims: usize = shape.len();
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
@ -436,17 +430,7 @@ pub fn call_cast_contiguous(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
// println!("Kernel {:?}", kernel_name.0);
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
// assert_eq!(input.length(), output.length());
|
|
||||||
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
|
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -473,15 +457,7 @@ pub fn call_cast_strided(
|
|||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
// println!("Kernel {:?}", kernel_name.0);
|
// println!("Kernel {:?}", kernel_name.0);
|
||||||
// assert_eq!(input.length(), output.length());
|
// assert_eq!(input.length(), output.length());
|
||||||
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -510,16 +486,7 @@ pub fn call_reduce_contiguous(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
@ -560,16 +527,7 @@ pub fn call_last_softmax(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
@ -612,15 +570,7 @@ pub fn call_affine(
|
|||||||
mul: f32,
|
mul: f32,
|
||||||
add: f32,
|
add: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let func = kernels.load_function(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -646,18 +596,9 @@ pub fn call_affine_strided(
|
|||||||
mul: f32,
|
mul: f32,
|
||||||
add: f32,
|
add: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let func = kernels.load_function(device, Source::Affine, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let size: usize = shape.iter().product();
|
let size: usize = shape.iter().product();
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
@ -695,15 +636,7 @@ pub fn call_where_cond_strided(
|
|||||||
(right_stride, right_offset): (&[usize], usize),
|
(right_stride, right_offset): (&[usize], usize),
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let func = kernels.load_function(device, Source::Ternary, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
|
||||||
|
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(
|
|
||||||
pipeline_state_descriptor.compute_function().unwrap(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
@ -751,10 +684,7 @@ pub fn call_index_select(
|
|||||||
let src_dim_size = shape[dim];
|
let src_dim_size = shape[dim];
|
||||||
let dst_el = ids_size * left_size * right_size;
|
let dst_el = ids_size * left_size * right_size;
|
||||||
|
|
||||||
let func = kernels.load_function(device, Source::Indexing, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||||
let pipeline = device
|
|
||||||
.new_compute_pipeline_state_with_function(&func)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ template <typename T> METAL_FUNC T erf(T in){
|
|||||||
float t = 1.0/(1.0 + p*x);
|
float t = 1.0/(1.0 + p*x);
|
||||||
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
|
float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
|
||||||
|
|
||||||
return (T) sign*y;
|
return T(sign*y);
|
||||||
}
|
}
|
||||||
template <typename T> METAL_FUNC T id(T in){ return in; }
|
template <typename T> METAL_FUNC T id(T in){ return in; }
|
||||||
template <typename T> METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; }
|
template <typename T> METAL_FUNC T gelu_erf(T x){ return x * (1 + erf(x * M_SQRT1_2_F)) / 2; }
|
||||||
@ -49,7 +49,7 @@ template <typename T> METAL_FUNC T gelu(T x){
|
|||||||
T x_cube = x_sq * x;
|
T x_cube = x_sq * x;
|
||||||
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||||
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
||||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + tanh(beta));
|
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user