Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)

Compare commits: 0.8.1...matmul-slo (3 commits)
Commits:
- 69c1fb1ee8
- c55ebaf477
- 4c91dd2ff4
@@ -67,6 +67,20 @@ impl DType {
             Self::F64 => 8,
         }
     }
+
+    pub fn is_int(&self) -> bool {
+        match self {
+            Self::U8 | Self::U32 | Self::I64 => true,
+            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
+        }
+    }
+
+    pub fn is_float(&self) -> bool {
+        match self {
+            Self::U8 | Self::U32 | Self::I64 => false,
+            Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
+        }
+    }
 }

 pub trait WithDType:
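The two predicates partition the dtype enum exhaustively. A minimal usage sketch, assuming the crate is imported as candle_core; the asserted values follow directly from the match arms above:

use candle_core::DType;

fn main() {
    // Integer dtypes: is_int is true, is_float is false.
    assert!(DType::U32.is_int());
    assert!(!DType::U32.is_float());
    // Float dtypes are the exact complement of the integer set.
    assert!(DType::BF16.is_float());
    assert!(!DType::BF16.is_int());
}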
@@ -177,14 +177,9 @@ impl Tensor {
         is_variable: bool,
     ) -> Result<Self> {
         let none = BackpropOp::none();
-        if is_variable {
-            let shape = shape.into();
-            let storage = device.ones(&shape, dtype)?;
-            Ok(from_storage(storage, shape, none, is_variable))
-        } else {
-            let storage = device.ones(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
-        }
+        let shape = shape.into();
+        let storage = device.ones(&shape, dtype)?;
+        Ok(from_storage(storage, shape, none, is_variable))
     }

     /// Creates a new tensor filled with ones.
@@ -222,14 +217,9 @@ impl Tensor {
         is_variable: bool,
     ) -> Result<Self> {
         let none = BackpropOp::none();
-        if is_variable {
-            let shape = shape.into();
-            let storage = device.zeros(&shape, dtype)?;
-            Ok(from_storage(storage, shape, none, is_variable))
-        } else {
-            let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
-        }
+        let shape = shape.into();
+        let storage = device.zeros(&shape, dtype)?;
+        Ok(from_storage(storage, shape, none, is_variable))
     }

     /// Creates a new tensor filled with zeros.
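These two hunks remove the scalar-broadcast fast path: a non-variable ones/zeros tensor was previously backed by a single scalar element broadcast to the requested shape, whereas now storage for the full shape is always allocated. A minimal sketch of the user-facing call affected, assuming the public Tensor::ones constructor forwards to ones_impl:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // All 2 * 3 elements are now materialized up front rather than
    // broadcast from a single scalar one.
    let t = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    assert_eq!(t.dims(), &[2, 3]);
    Ok(())
}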
@@ -185,8 +185,8 @@ impl Benchmark for Matmul {
     type PreProcessData = (Tensor, Tensor);
     type RunResult = Tensor;
     fn preprocess() -> Result<Self::PreProcessData> {
-        let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
-        let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
+        let lhs = Tensor::randn(0f32, 1., (1024 * 4, 1024 * 4), &Device::Cpu)?;
+        let rhs = Tensor::randn(0f32, 1., (1024 * 4, 1), &Device::Cpu)?;
         Ok((lhs, rhs))
     }

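This switches the benchmark from a square 1024x1024 by 1024x1024 product to a 4096x4096 matrix times a 4096x1 column, i.e. a matrix-vector multiply with a much larger lhs. A standalone sketch of that workload; the Instant-based timing here is illustrative and not the crate's Benchmark harness:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let lhs = Tensor::randn(0f32, 1., (1024 * 4, 1024 * 4), &Device::Cpu)?;
    let rhs = Tensor::randn(0f32, 1., (1024 * 4, 1), &Device::Cpu)?;
    let start = std::time::Instant::now();
    // (4096, 4096) x (4096, 1) -> (4096, 1)
    let out = lhs.matmul(&rhs)?;
    println!("matmul took {:?}, output shape {:?}", start.elapsed(), out.dims());
    Ok(())
}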
@@ -41,6 +41,10 @@ impl Optimizer for SGD {
     type Config = f64;

     fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+        let vars = vars
+            .into_iter()
+            .filter(|var| var.dtype().is_float())
+            .collect();
         Ok(Self {
             vars,
             learning_rate,
@@ -116,6 +120,7 @@ impl Optimizer for AdamW {
     fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
         let vars = vars
             .into_iter()
+            .filter(|var| var.dtype().is_float())
             .map(|var| {
                 let dtype = var.dtype();
                 let shape = var.shape();
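Both optimizers now discard non-float variables at construction, using the DType::is_float predicate introduced above, so integer buffers registered alongside trainable weights are excluded from gradient updates rather than causing them to fail. A small illustrative sketch, assuming SGD and Optimizer are re-exported at the candle_nn crate root:

use candle_core::{DType, Device, Result, Var};
use candle_nn::{Optimizer, SGD};

fn main() -> Result<()> {
    let weight = Var::zeros((4, 4), DType::F32, &Device::Cpu)?; // trainable
    let counter = Var::zeros((4,), DType::U32, &Device::Cpu)?; // integer buffer
    // With the filter, the U32 var is silently dropped instead of
    // receiving a floating-point update it cannot represent.
    let _sgd = SGD::new(vec![weight, counter], 0.01)?;
    Ok(())
}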