@@ -66,6 +66,18 @@ impl Attr for GraphProto {
    }
}

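// ONNX encodes a `strings` attribute as a list of raw UTF-8 byte arrays, so
// turning it into owned Rust `String`s can fail on invalid UTF-8; such errors
// are wrapped into a `candle::Error` below.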
impl AttrOwned for Vec<String> {
    const TYPE: AttributeType = AttributeType::Strings;
    fn get(attr: &onnx::AttributeProto) -> Result<Self> {
        let mut ret = vec![];
        for bytes in attr.strings.iter() {
            let s = String::from_utf8(bytes.clone()).map_err(candle::Error::wrap)?;
            ret.push(s);
        }
        Ok(ret)
    }
}

impl AttrOwned for Tensor {
    const TYPE: AttributeType = AttributeType::Tensor;
    fn get(attr: &onnx::AttributeProto) -> Result<Self> {
@@ -1310,6 +1322,233 @@ fn simple_eval_(
                    .broadcast_add(&c.broadcast_mul(&beta)?)?;
                values.insert(node.output[0].clone(), output);
            }
"LSTM" => {
|
||||
let direction = get_attr_opt(node, "direction")?.unwrap_or("forward");
|
||||
if direction != "forward" {
|
||||
bail!("LSTM currently only supports direction == \"forward\"");
|
||||
}
|
||||
let num_directions = if direction == "bidirectional" { 2 } else { 1 };
|
||||
let hidden_size: i64 = get_attr(node, "hidden_size").copied()?;
|
||||
let input_forget = get_attr_opt(node, "input_forget")?.copied().unwrap_or(0);
|
||||
if input_forget != 0 {
|
||||
bail!("LSTM currently only supports input_forget == 0");
|
||||
}
|
||||
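                // Given the bail above, num_directions is always 1 here; it is kept as a
                // variable so the shape handling below mirrors the ONNX spec.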
                let activations_default = vec![
                    "Sigmoid".to_string(),
                    "Tanh".to_string(),
                    "Tanh".to_string(),
                ];
                let activations = get_attr_opt_owned::<Vec<String>>(node, "activations")?
                    .unwrap_or(activations_default.clone());
                if activations != activations_default {
                    bail!("LSTM currently only supports default activations ({activations_default:?})");
                }
                // activation_alpha and activation_beta don't apply to (Sigmoid, Tanh, Tanh) so ignoring them is okay
                if get_attr_opt::<f32>(node, "clip")?.is_some() {
                    bail!("LSTM does not currently support clip attribute");
                }

                // The shape format of inputs X, initial_h and outputs Y, Y_h.
                // If 0, the following shapes are expected:
                // X.shape = [seq_length, batch_size, input_size],
                // Y.shape = [seq_length, num_directions, batch_size, hidden_size],
                // initial_h.shape = Y_h.shape = [num_directions, batch_size, hidden_size].
                // If 1, the following shapes are expected:
                // X.shape = [batch_size, seq_length, input_size],
                // Y.shape = [batch_size, seq_length, num_directions, hidden_size],
                // initial_h.shape = Y_h.shape = [batch_size, num_directions, hidden_size].
                let layout = get_attr_opt(node, "layout")?.copied().unwrap_or(0);
                if layout != 0 {
                    bail!("LSTM currently only supports layout == 0");
                }

                // The input sequences packed (and potentially padded) into one 3-D tensor
                // with the shape of `[seq_length, batch_size, input_size]`.
                let x = get(&node.input[0])?;
                // XXX: depends on layout
                let (seq_length, batch_size, input_size) = x.dims3()?;
                // The weight tensor for the gates.
                // Concatenation of `W[iofc]` and `WB[iofc]` (if bidirectional) along dimension 0.
                // The tensor has shape `[num_directions, 4*hidden_size, input_size]`.
                let w = get(&node.input[1])?;
                // The recurrence weight tensor.
                // Concatenation of `R[iofc]` and `RB[iofc]` (if bidirectional) along dimension 0.
                // This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
                let r = get(&node.input[2])?;

                let get_opt = |i: usize| {
                    node.input
                        .get(i)
                        .filter(|s: &&String| !s.is_empty())
                        .map(|s| get(s))
                };
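                // Per the ONNX spec, an omitted optional input is encoded as an
                // empty string in node.input, hence the filter above.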

                // The bias tensor for input gate.
                // Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
                // This tensor has shape `[num_directions, 8*hidden_size]`.
                // Optional: If not specified - assumed to be 0.
                let b_default: Tensor;
                let b = match get_opt(3) {
                    Some(n) => n?,
                    None => {
                        b_default = Tensor::zeros(
                            (num_directions, 8 * hidden_size as usize),
                            DType::F32,
                            x.device(),
                        )?;
                        &b_default
                    }
                };

                // Optional tensor specifying lengths of the sequences in a batch.
                // If not specified - assumed all sequences in the batch to have length `seq_length`.
                // It has shape `[batch_size]`.
                let seq_lens_default: Tensor;
                let seq_lens = match get_opt(4) {
                    Some(n) => n?,
                    None => {
                        seq_lens_default =
                            Tensor::full(seq_length as i64, (batch_size,), x.device())?;
                        &seq_lens_default
                    }
                };
                let seq_lens_is_default =
                    (seq_lens.to_vec1::<i64>()?.iter()).all(|e| *e as usize == seq_length);
                if !seq_lens_is_default {
                    bail!("LSTM currently only supports default value of seq_lens");
                }

                // Optional initial value of the hidden. If not specified - assumed to be 0.
                // It has shape `[num_directions, batch_size, hidden_size]`.
                let initial_h_default: Tensor;
                let initial_h = match get_opt(5) {
                    Some(n) => n?,
                    _ => {
                        initial_h_default = Tensor::zeros(
                            (num_directions, batch_size, hidden_size as usize),
                            DType::F32,
                            x.device(),
                        )?;
                        &initial_h_default
                    }
                };

                // Optional initial value of the cell.
                // If not specified - assumed to be 0.
                // It has shape `[num_directions, batch_size, hidden_size]`.
                let initial_c_default: Tensor;
                let initial_c = match node.input.get(6) {
                    Some(n) if !n.is_empty() => get(n)?,
                    _ => {
                        initial_c_default = Tensor::zeros(
                            (num_directions, batch_size, hidden_size as usize),
                            DType::F32,
                            x.device(),
                        )?;
                        &initial_c_default
                    }
                };

                // The weight tensor for peepholes.
                // Concatenation of `P[iof]` and `PB[iof]` (if bidirectional) along dimension 0.
                // It has shape `[num_directions, 3*hidden_size]`. Optional: If not specified - assumed to be 0.
                let p_default = Tensor::zeros(
                    (num_directions, 3 * hidden_size as usize),
                    DType::F32,
                    x.device(),
                )?;
                let p = get_opt(7).unwrap_or(Ok(&p_default))?;
                let p_is_zeros = (p.to_vec2::<f32>()?.iter()).all(|v| v.iter().all(|e| *e == 0.0));
                if !p_is_zeros {
                    bail!(
                        "LSTM currently only supports default value of p (a Tensor of all zeroes)"
                    );
                }

                // these all have [num_directions, ...] shapes
                let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
                let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
                let b = b.get(0)?; // concat of [wb[iofc], rb[iofc]] has shape [8*hidden_size]
                let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
                let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
                let wb = b.index_select(&idx_wb, 0)?;
                let rb = b.index_select(&idx_rb, 0)?;
                let c = initial_c.get(0)?;
                let h = initial_h.get(0)?;

                // w, r, wb, rb are all iofc but lstm expects ifco
                // so we need to move some stuff around
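                // ONNX packs the gates along dim 0 as [i, o, f, c] (each block of
                // size hidden_size), while candle_nn's LSTM follows the PyTorch
                // layout [i, f, c, o]; the index tensors below build a gather that
                // permutes the gate blocks accordingly.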
                let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
                let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
                let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
                let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
                let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
                let w = w.index_select(&idx_ifco, 0)?;
                let r = r.index_select(&idx_ifco, 0)?;
                let wb = wb.index_select(&idx_ifco, 0)?;
                let rb = rb.index_select(&idx_ifco, 0)?;
                let vmap = candle_nn::VarMap::new();
                vmap.data().lock().unwrap().extend([
                    ("weight_ih_l0".to_string(), candle::Var::from_tensor(&w)?),
                    ("weight_hh_l0".to_string(), candle::Var::from_tensor(&r)?),
                    ("bias_ih_l0".to_string(), candle::Var::from_tensor(&wb)?),
                    ("bias_hh_l0".to_string(), candle::Var::from_tensor(&rb)?),
                ]);
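                // candle_nn's lstm() pulls its parameters from the VarBuilder using
                // PyTorch-style names (weight_ih_l0, weight_hh_l0, ...), so seeding a
                // VarMap with those names lets us reuse the existing LSTM implementation.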
                use candle_nn::rnn::RNN as _;
                let lstm = candle_nn::rnn::lstm(
                    input_size,
                    hidden_size as usize,
                    candle_nn::rnn::LSTMConfig::default(),
                    candle_nn::VarBuilder::from_varmap(&vmap, w.dtype(), w.device()),
                )?;

                let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
                let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
                    Some(vec![])
                } else {
                    None
                };
                for t in 0..seq_length {
                    let x = x.get(t)?;
                    lstm_state = lstm.step(&x, &lstm_state)?;
                    if let Some(h_acc) = &mut h_acc {
                        h_acc.push(lstm_state.clone());
                    }
                }
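                // Per the ONNX spec the three outputs are all optional: Y (the
                // hidden state at every timestep), Y_h (the last hidden state)
                // and Y_c (the last cell state).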

                assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
                if let Some(name) = node.output.get(0) {
                    let h_acc = h_acc.as_ref().unwrap();
                    let h_acc = lstm.states_to_tensor(h_acc)?;
                    let h_acc = h_acc.reshape((
                        seq_length,
                        num_directions,
                        batch_size,
                        hidden_size as usize,
                    ))?;
                    values.insert(name.clone(), h_acc);
                }
                if let Some(name) = node.output.get(1) {
                    values.insert(
                        name.clone(),
                        lstm_state.h().reshape((
                            num_directions,
                            batch_size,
                            hidden_size as usize,
                        ))?,
                    );
                }
                if let Some(name) = node.output.get(2) {
                    values.insert(
                        name.clone(),
                        lstm_state.c().reshape((
                            num_directions,
                            batch_size,
                            hidden_size as usize,
                        ))?,
                    );
                }
            }
            op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
        }
    }
