mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Merge pull request #34 from LaurentMazare/simpler-dtype-trait
Put more requirements on the withdtype trait.
This commit is contained in:
@ -15,11 +15,7 @@ pub enum CpuStorage {
|
||||
}
|
||||
|
||||
trait Map1 {
|
||||
fn f<T: WithDType + Copy + num_traits::NumAssign>(
|
||||
&self,
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
) -> Result<Vec<T>>;
|
||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
|
||||
|
||||
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
||||
match vs {
|
||||
@ -35,13 +31,7 @@ trait Map1 {
|
||||
type C = CpuStorage;
|
||||
trait Map2 {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType + Copy + num_traits::Num + 'static>(
|
||||
&self,
|
||||
v1: &[T],
|
||||
l1: &Layout,
|
||||
v2: &[T],
|
||||
l2: &Layout,
|
||||
) -> Result<Vec<T>>;
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
|
||||
|
||||
fn map(
|
||||
&self,
|
||||
@ -101,11 +91,7 @@ struct Sum<'a> {
|
||||
}
|
||||
|
||||
impl<'a> Map1 for Sum<'a> {
|
||||
fn f<T: WithDType + Copy + num_traits::NumAssign>(
|
||||
&self,
|
||||
src: &[T],
|
||||
src_layout: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
fn f<T: WithDType>(&self, src: &[T], src_layout: &Layout) -> Result<Vec<T>> {
|
||||
let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
|
||||
for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
|
||||
let mut dst_index = unstr_index;
|
||||
@ -153,11 +139,7 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||
struct Affine(f64, f64);
|
||||
|
||||
impl Map1 for Affine {
|
||||
fn f<T: WithDType + Copy + num_traits::NumAssign>(
|
||||
&self,
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
let mul = T::from_f64(self.0);
|
||||
let add = T::from_f64(self.1);
|
||||
Ok(unary_map(vs, layout, |v| v * mul + add))
|
||||
@ -292,11 +274,7 @@ impl Map2 for MatMul {
|
||||
}
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim<T: WithDType + num_traits::NumAssign>(
|
||||
s: &mut [T],
|
||||
shape: &Shape,
|
||||
dim: usize,
|
||||
) -> Result<()> {
|
||||
fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
|
||||
// [self] stores data in a contiguous way starting at offset 0.
|
||||
let dims = shape.dims();
|
||||
let elem_per_slice = dims[dim];
|
||||
@ -332,7 +310,7 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
|
||||
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
|
||||
D::cpu_storage_as_slice(self)
|
||||
}
|
||||
|
||||
|
@ -31,7 +31,7 @@ impl DType {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WithDType: Sized + Copy {
|
||||
pub trait WithDType: Sized + Copy + num_traits::NumAssign + 'static {
|
||||
const DTYPE: DType;
|
||||
|
||||
fn from_f64(v: f64) -> Self;
|
||||
|
Reference in New Issue
Block a user