mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Marian MT model (#1210)
* Skeleton files for the marian MT model. * Marian initialization. * Implement the attention forward method. * Forward pass for the encoder side. * Expose the encoder and decoder. * Start plugging the decoder. * Forward pass for the decoder layer. * Set up the marian example. * Add some missing backtraces. * Bugfix.
This commit is contained in:
@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let ids = match self.ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &self.ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
||||
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((a, b)) => &src[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
||||
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
};
|
||||
let dim = self.dim;
|
||||
let ids_dims = self.ids_l.dims();
|
||||
@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
|
||||
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
let src = match layout.contiguous_offsets() {
|
||||
Some((a, b)) => &src[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-select" })?,
|
||||
None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
|
||||
};
|
||||
let dim = self.dim;
|
||||
let n_ids = match self.ids_l.dims() {
|
||||
@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
copy_strided_src_(v1, &mut dst, 0, l1);
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
|
||||
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
|
||||
@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
|
||||
|
||||
let ids = match self.ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &self.ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "gather" })?,
|
||||
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
};
|
||||
for left_i in 0..ids_left_len {
|
||||
let start_ids_idx = left_i * ids_right_len * ids_dim_len;
|
||||
@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
copy_strided_src_(v1, &mut dst, 0, l1);
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
let dim = self.dim;
|
||||
@ -2539,25 +2539,25 @@ impl BackendStorage for CpuStorage {
|
||||
Self::U8(ids) => {
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
Self::U32(ids) => {
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
Self::I64(ids) => {
|
||||
let ids = match ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &ids[a..b],
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" })?,
|
||||
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
};
|
||||
IndexAdd { ids, dim }.map(self, l, src, src_l)
|
||||
}
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user