Add a cuda kernel for embeddings.

This commit is contained in:
laurent
2023-06-26 11:47:57 +01:00
parent 5952c3fa91
commit 16f0f5b9d2
5 changed files with 101 additions and 11 deletions

View File

@ -46,6 +46,7 @@ macro_rules! with_dtype {
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}),
}
}
@ -56,6 +57,7 @@ macro_rules! with_dtype {
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}),
}
}
@ -66,6 +68,7 @@ macro_rules! with_dtype {
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}),
}
}