mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Softmax implementation for cuda. (#747)
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
|
||||
use candle_kernels as kernels;
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
use cudarc::driver::{
|
||||
@ -383,7 +383,7 @@ impl BackendDevice for CudaDevice {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum CudaStorageSlice {
|
||||
pub enum CudaStorageSlice {
|
||||
U8(CudaSlice<u8>),
|
||||
U32(CudaSlice<u32>),
|
||||
I64(CudaSlice<i64>),
|
||||
@ -394,7 +394,7 @@ enum CudaStorageSlice {
|
||||
}
|
||||
type S = CudaStorageSlice;
|
||||
|
||||
trait Map1 {
|
||||
pub trait Map1 {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
@ -416,7 +416,7 @@ trait Map1 {
|
||||
}
|
||||
}
|
||||
|
||||
trait Map2 {
|
||||
pub trait Map2 {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
@ -441,7 +441,7 @@ trait Map2 {
|
||||
}
|
||||
}
|
||||
|
||||
trait Map2InPlace {
|
||||
pub trait Map2InPlace {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
@ -472,7 +472,7 @@ trait Map2InPlace {
|
||||
}
|
||||
}
|
||||
|
||||
trait Map1Any {
|
||||
pub trait Map1Any {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
@ -495,7 +495,7 @@ trait Map1Any {
|
||||
}
|
||||
}
|
||||
|
||||
trait Map2Any {
|
||||
pub trait Map2Any {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
@ -532,7 +532,7 @@ impl Map1 for Clone {
|
||||
}
|
||||
}
|
||||
|
||||
fn kernel_name<T: WithDType>(root: &str) -> String {
|
||||
pub fn kernel_name<T: WithDType>(root: &str) -> String {
|
||||
let dtype = T::DTYPE.as_str();
|
||||
format!("{root}_{dtype}")
|
||||
}
|
||||
@ -1310,8 +1310,8 @@ fn slice_src_and_dst<'a, T>(
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CudaStorage {
|
||||
slice: CudaStorageSlice,
|
||||
device: CudaDevice,
|
||||
pub slice: CudaStorageSlice,
|
||||
pub device: CudaDevice,
|
||||
}
|
||||
|
||||
pub trait CudaDType: Sized {
|
||||
|
Reference in New Issue
Block a user