mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Cuda support for the mnist training. (#277)
* Cuda support for the mnist training. * min/max fix + testing. * Add the argmin/argmax tests. * More cuda support for argmin/argmax. * Cuda kernels for argmin and argmax.
This commit is contained in:
@ -144,7 +144,8 @@ __device__ __forceinline__ double copysigng(double a, double b) { return copysig
|
||||
|
||||
__device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); }
|
||||
__device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
|
||||
|
||||
__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
|
||||
__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
|
||||
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
|
||||
|
Reference in New Issue
Block a user