From 695838432747b9c9460e74cd3f5086642b6897a9 Mon Sep 17 00:00:00 2001 From: Ogundepo Odunayo Date: Thu, 9 Nov 2023 12:49:17 -0500 Subject: [PATCH] Add support for TrOCR Model (#1303) * add bce with logit loss * add bce with logit loss * remove imports * fix tiny bug * add test documentation and refactor function * fix test cases and formatting * add trocr model * fix formatting * commit the actual model lol * more formatting * remove tokenizer config --- .../examples/trocr/assets/trocr.png | Bin 0 -> 36733 bytes .../examples/trocr/image_processor.rs | 154 +++++++ candle-examples/examples/trocr/main.rs | 132 ++++++ candle-examples/examples/trocr/readme.md | 16 + candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/trocr.rs | 434 ++++++++++++++++++ candle-transformers/src/models/vit.rs | 45 +- 7 files changed, 767 insertions(+), 15 deletions(-) create mode 100644 candle-examples/examples/trocr/assets/trocr.png create mode 100644 candle-examples/examples/trocr/image_processor.rs create mode 100644 candle-examples/examples/trocr/main.rs create mode 100644 candle-examples/examples/trocr/readme.md create mode 100644 candle-transformers/src/models/trocr.rs diff --git a/candle-examples/examples/trocr/assets/trocr.png b/candle-examples/examples/trocr/assets/trocr.png new file mode 100644 index 0000000000000000000000000000000000000000..06886aab121616bdfc526adbae2e66e9f3becc50 GIT binary patch literal 36733 zcmV(`K-0g8P)CrBC`~$)+WLl58tjvaK9U zMT#=ToC8RZ0FeWcfepasG`l-_ZoVDPsrvoVeP=dc2?7M=pDfn1&$Bag`*xr1KBrFA zTUBqNWflQM1Vj;GKokL@7yy6}5LlQ+96$g7AYn$OD6x#1gf)l&0)|LL2w;c^n3>aG zMF0Siz(OhrfPjDz0)l`DAOa##_PId>7DC0$Ktu)sMOcjEpfwC#&qd3qKm<@JiU0&8 z9V-J;LwM|-h6nyZ?re>Uh8>LKtMp$E1bV4ItBqmLW>|EEQ09_ z!*pq&6#$@YLLz`-0AN8%*AyA(9nBzsKqwXfiPBp@NtY+|2E>^TSR8P=G@&E}&{o&Jabf|5O&Gv5=4AONOE0|Myu>%IFDM_7R6mG4oq ztA0=Zuas1)fBSk#3kV8GZv_dab8T^}vg!PVOd$XRl7I*yqV%>=ByiG|DkKa5M45WT z)UFUXMI#CJHVG8$*#tccVR1pHpF?kBB@|&ma76@>z$m4}TI;1ZoNnd9gqZ?B0uT~3 z=`Z+Nr@Wl;Gzn5eB0JyWM>hLqY5V|!(6c=tK|Iv}fmBmTy1gT1XXJBkp1tFH7wsKK zy+I&}oc87Hxtw)BfT?bJ`Om*6|FVTon}x3|T8jvX&Ax)5KtNCtXY)7PNVDY+kRTF) z086(0K)SW5bXUwwP!UIj0Z0&lWN}}FMN5Jgw^czoWlv%hU?JpmnF(b}O)VNG3@An{ z5TKH53L^lfyJx0W(jk)|L$(6(oNMbeIde%+YEdoz-ZNj)g?rlfPFqgBad38*nQB^8 zCOs1pge85zsqu!VSQ|^@!KLFfBc3x0KxayyQx9O#)mnTmXWbvl()0em-ctWpCz8$n zZ&$TzKok%W%6zv|eGmZvC*8|{dJ8q($0YR;5KZdjr`x0m2oZuXYi7yjGa#L@uF~%S zWXt(}ICEo`%&`Lik!+V2;8YC>p7LJ4x17HxuU51{LLgb( zcMJi5*d9QHB$zG$0TD>2nFx!UE?&wM$Kv+Nj0nP#?nEMtblR52nRT2QDLoi&ae-lm z0Z3}6B4?W(_Wnu=MUXrYVd<^;)Yky$fqE!RsRD4M2gZGS+ufNU%-PQPlE;Ck|M|3I zZK_X^P((n05aCn}mD&6#NYWcNVCuuaoCT2mh}<*%GpixPe1H%z(~;Tt(px@r{Fyh0 z=lcA6@-JJkhb4WRqGi!2

~Mrz~e1Jg}s2HbXX%Q{xwyfLRE1wo}v2P9XqEv9}ai6Gi|@$Bpvsdqt;< z&F<|~;od!*GJ4xX_fNUU6az>hQWD`)br5B!K10s~`qsW1rWh!uS|fGMdSl6KoD>FK z{D!=&%d?~jU;gv&$-iu&3`_cU=es6M1AQ(_YEPm<;hvR_V6x50q?WK~KtMuB2-yfJ zEu6-~0JDkCkQxqzAS@aQ<&=#^o$gYfwbM$50iXZ_^)P)xfgUKpoQR`BB1E7?a3RG2 z7o%>_gEW)Wl)=}o01ye6q7-t47AMFn;d zlgPj?A_Qw$6o81ZIN9e8GXO9HBLOJ4iwK+=3jzS3v;UnlZ#?_orfQ2AFdcWMxEAvg z6pCR|imVZhj3lwR8dHHQ1RE0Z<7G5D`d%OSuE5-h(g+KpN8* zNv#n8fP@7J06U(tpt$X3PD4)tWDf$*`63oqTrXIV4LEbYQ9uY0VvPg@U=fu73jrzu zK!acrh=fneGs>27FDFJM?u8-J&>&b5V4+u>{{l-36tm!LdVbCvhf{}}S{;j^A*2S; zsmrmaAbN2f%FK}O$-hp~i!fX^3F!3YFIE9FD6)P-6(Z?{>x2*(MXXA5(-MPMHc{d4mjVwcq8!Ddz*enIJd zmc{@xy>QMQGj$=L_i`kAP#}_q^Lydjr9`ST@d`ORWC5^vBKqFq|5q)V5P<~o6!1y} z1avl$5=6_05a$)^1iCtdkX-<4SU>?22%tqw5xOPEHeDf*=D_wc0JC5?s%#M~PFxnT zmbs^qU}-!60LTD@h+@)l38krVsF_7XY=BH`$W0)NtoQc+%#HX~rjURL3yFw`$}9j% zPmjz3irBge45KwV-Fz9cm90T60HQ%mZI?_K5MeR<7g+Z)2(#OBQeYA-pb+3W$0w&O zQA{D8o)xf&i=4KFXW>&h<1C9hWGSBj1O=2VJ>>7nKfiptruuhFG$Hoj>7`wF5otxZs#!ydYA5CJK*Jw*VlLQJ>JEPXFU;*dm8788wnfO@tR1VsU41p&P?37o!M zvycG;6C#LYK1Svl^nwr&k;h#Kh)_t7xTgWKWQ242N=q)l8Fs?y$q9tkG9eMLWg^N9 z7+^t_hENbsJ-NvgfSa8qQ?S{-EvV$&fZxl>l8VT7?kswno)7U93_yAg$hQ*8o@@FB zzZVYt*D6^`Irfr6L1pWpw=lFuRGQ;zqZGwH>MW^6I!X}$p|K*sz100wB?usKX8&i4 zQF<6a1~(vQL2&|>bOb0 zmT}m#i?a8WVWUgw!KXc~Q<4kut0(OPz&m8x}lMK}kg+0EE&yqU!!Dt(Bo)8~249h)-a0v8`mrc1JL zXV0QoLKDwWMTjrE&gB%HbY}YT_vBxXXhvjaASa6+CRurDx~%~KvsIFW-+XXzrOTmb zv(@JkfiSZsK?GEw2&^!pFf%IP)IDHUNF>NT44{_**ux02L_I_lvBFH0S`ufQ>WlOG zv@js0i&l(o?34Am1c?YpQIr`nz3|tW;j+_25R52b7U-SV5D>6sO$e4LUC&^QQKSUK z#-w!mlo*g`vCOw8N$(bW(GLVd#9oykk{}>TPcxl6Axc05d#dIXb3z0p#Y|<%tjK)) z(+ZeQD?wPq0D9vUgfkTG_vBxbV2237G9wYNAZkuAv^cBhrb{&`43XljW#l~&qa?LpBwGjbGtAx;6QJMikII%tm3n*ns0ka_Xk_A~vfi*hK z14w7PFf@iuq?5u-PJk2&$7x;;Xl6hZrl=UVN?4Ql2T}dl9kyUCX<~y!C_qR6Vhm{= zq6BfbcnU880Gy^61O&hU5dm8g03;$yU0=1uBf~njL6@{7qBKP4D4UQc*7!&?S83Xu zXN+<(6cra0!eaPS7I%8Y#aldc<(+8~_ILw&;aA64P$ayB-!CG5n7J;!^PHA>)HD5c zW@w#d3B8(G^D@#4ZhX&@`1VB3`2SpqYD<>gt2^(x{=S%}lFbjT(m);AbmCIj1C=tX zJYi;Cs^>!9T5y#PeWKI$I!__9ASq70dRsLMSb?U~;};PYK(OhB6{X7ov$Q@Cm2PCf zsTrymQG|eP>KwyxoqdUGgE0#NGwaM?jRg_da!#|#k$4NAiram`(^td-06>lf20=v2 zfaC&%KE_r#$tz9#^mplH)^a!Rvgry0R?rpG$R#$igfi%uHU`klv5mwAeF-yLy9hmI zRgdXv!xSMzK>(E*-HVI%6~v~xX)%u{RiuavyofrcxWKoB64LyRMXm6<^H_XO{)LK8 z%j}WToN~howC z_fOYuP>NWHY!>KdX4FbJ2>`meAUW5A=!r_>9za=wwYGs@o>zkWy(Fv{yilASn?^`coRW)2HPnyz)ACu#4MxX(NUJT&%mD4#|hFbJUL#F(cHTv{-p zw%Vk8Sc}zUEY?E08KVeN$5HKkD6RnnO0jBGXvHF;AX0)zoUV}+x9?%QpmcipaXJXVx|oD=J(z}3M?xnNFX4L3Ty%>ZqhCY zYdq!ldzU&*dYCyv>Ss5DJZYITBlEoQ+*A7#hm!EaX!jvQb+)d zy>m$KB3a23E=Cq*>H24F+Gj7!(@}n_D9sICEC#~U>H|)9A(rlrg&7F&)z*uDyK;&W z`kuTB`R7uZXI`f_KeQpD-VP&x8ETDb9Xtp`AY|tcmibMXR}z15h2kU)9lrB=JxUdxaX1(3iA?qhfZktN0A{6i zHmeWP)-s|ZaE+TNHFM7#*sgXE3 zjssOb`jKbnqp)2G9Reo$~iA#F(R^9ilJdq1Duu=*h^x@jNjyy zM5R@Oi&amO#s2BLu!z4W|N2DtLdJ~1oSOfDB8wb+k{~xi(kQw1R8TCq z4zmc0<;684DR3EUkXZAML-Det?X?FJA|h1OOQ2^U1=>a*n+vw|8MAofXGIxK`y~ZA z5d`|0wo`TwPu};yG~|jAlpyJ{=lRfY!RfOvest~ItDzMDC0Q*+7NW}(z$y6;kd8;> zlCyh)^_IeWDXk-=d;+hcUi{k?>>&}a_lf^6 zm;ZAXKsrA(f-nkVZ(=%-`kXyN4nF~DztmcDwGNn-XSJ{po>H~O%pf2nh@7xY{L}II zTJD;h6#)=IizJ+?5^II1$44gs0!Y-W>rtEn3*5s1vSCk6wvyZ?S}3J)FLqMf9Bc|L zKoZB&@fN;%I8ItTSlyJVF(ypw@Yo1zu1(+57?>|8_Z1&hs`f}yO@qQ=IUn!-yZwD9 ze3Rei4s0mqA&wI>Q9E{YCRX&+(Mv{|V{w!KkT5~I5IMsDdRd5BWfW$S_p>Mx&N-gN z)Tiuu3nLp$^MsX+Rk{0oqDerP4a8Rn0SN?GhB ztE6#y6_Z2l@|sJ#rO1{G2_TXtkf?VaHewZuW?&mxd3;9>X4?g@8Om=o2wTjGS&29W zt1})jB$1TE11+PYKnww+q@e=aF@-CLP82*+;Y39qu?1kTt{ znp62OK$dJ1z_rZOAp}qHiH0Zm2*Zx4gi3{M3BW0IV&NS z#Y>#?Ja#b**yA4jo_tRLG!imfMi3zvY)FcYC}a@JqO_JyU+o{}S6?%bv~x}^QKA*b zKq}=6U?vwMc9msQoGj$$_WZSH9LOIX4&p?)Dvlknm_mC(1d=#Cj;z!YFwK1AOW4@TteN~EV0u^?&K6&2YmB?v7GDMS_)B%+iyol}ZNp&odOi&d28 zkaM|4Izq>Q16C5PyOK9&eJ{l{q+5iJhIT<4agFip8OP})L06e5piZt!4xM?PGetLb z09{nyErmrO%s@&=5^_xFXNh2E9(VReak{)b$hH)DJy)FH>PEFyurVaA1kIdE(zkui z^7V6F_mzyV-sjm)lG1S_E-_eTJa;Lri~=Dc5mL$-={P5A#i5e1Ag-gB5Os>=BO-}K z0ZxQ-O00kKZ=Z7@%=_mr>jI#}Ak^CdDa+tfX%YZsx%LoAJm(uw^#^1>IVoR(CBHfH<)B9IPZ$JF{k zXcWs`ASAG0DWeC?&KG+ab{?#a>mbMNOL9lk>N{R6X<+NOgT|8wJ2=&G!B0HnxKUR- za|4Cs*==qY7P$zr=OEbIp+Nwk5zQShhlMMz8Y)CB$2K~pZs)U2cchUh$M5LPH{YGS zI8h+c00^QMLw!ATPG_(}tQ}xhuIYt6d(UIhRXF7pWrVDT86)?$dpUdND> zfD`~#$`c^4r~qCEL4vO18Zw^CUCO`Qfye#5uvNXxumK=vfL&!wuI*Tg z+*e$0IWRFm80Ef8cv)JB%2oq{HFUO*_?=(jWuV?AnIC|;{%Ux5$LFW3gBL~RM68hGFRw?CF+5$u&Fc6s_ng7B( zmfX4AtrvI=Cyni;AhsZwbmNC4Glnby09!_@5`o9|nA(b)pswazW{Z-N0RxhTCyF zHCdY4`TW=3f3+W`vvdj6%m7N2{!MZXM}~OY4a$>90>>7dgHO(t*44^NIOx>i7XDPF z5!7>8ST9@tD3XnpQ!-F=mEw&4IFp8PkRJ#HhnG(h(wTq zsfZ{{XIc=iNjzwq00byxBRL7rq*zPUkU#-=$L|~sN7j34QMsLVo^F(P;s~6mIRdSI zm(M)l%<~%KYO#Kd-<=QxV#-csL^){LzTTI3C9G%Q__D&`!olBLh|~)o#q(Zk1FKg# zp6tg{Eda@8X2vFgBtQz-<|2Lk*ru&B)j&syN#eW+P)h5^000caVg=lVp(E`N1@)19 zwf~kpz#U_;9Skk^(*mKS^X-ATawqCM&?ZAq` z_YL%!kjP1Ij|D`)=9-D1ESRCo20UY(X&rdU!il|)K7D*-`T9brGYz$98i_0h&_n?w zoW8;L;^_p7gE$>DSX}z)$GsP-IDKyfl2R0bS|lWd*8%iL>qA@RL5}&O2oAoyPcG7s z7ru$rXl)uRKAY(%^q}MT)6e`}GkkBg3}NsJ*Cnv+if7Uml;4Heg@~XM0?&iksI!Vt zzw`6M858ylsrO0!*z-?SR4Ih+G%Aaf<>7pH>5RVZ zCK;<6LW#9+kVH&cFpL7E5>^UOY3Jy}U-Ku|ZG6aFS}t_vb0U-hV!b#R0Q9;BLa$N{ zEgIN&Zzs2W`w&nLU8~3ejW(UggtbDu*xzi4n{I?4*^b@2kwTf1zVTlk3f$Y)ydkNx zN^|Ei>&^?BN{QJSCON;Zh_npOot)b7{DH}Q)0r<+`{M)!gimc=nt|?FKF`8{yH4sY zn^z);Nwgz+BnjTLt`Jcnm#~=mN4&Me)(8QmDHQ-cTMCgejk_Z-Fp~hP(?|mV7N5nb z@1=(tpE7k$pF!B0SW8xMMnIk+4hWL+he=v=advt2s};w9=?uV&v!AthUY+NezB4|T z%;MqaKRj=i&96qHvzT1~0P+yBZan3N}%Mcmv7yFpxl4HCaS&&iiY}uD|J%;l=N1blto`6eGRvxM@#8 zB*27FP%`>)BnUNx&wj=!-n|12YR_J{deT>Lle5(t;sN0EkQFv9)5i zLwsH`98qW$Rx+Mrg{66p}*6l@Q&eHSkaK?z@H^6Eq}1vEj+P>X-C6jsB1XL71}%B#4xC zG(iCtP)NY0N7uJx9G{{LPm^A`PEP@xIU};Hf~6uMwvzD;0FkB$Q!A(!2KWxDX9qmpad zaq&v-156ZwpaSXf`sb3^C`VyiD->%b7se`}s4L|fPWLre-Ow#|(S*JsU=U$t1vLWL z+=ACE1cGtiX|dN#ZfU+Rp+q1s)>=V@Y}~AK$RZ=s(%U#K51{w=Q{uKMmXU#)r@Tmk zQxyYhAn2Kn)~2-xS;Yk{z4%^25T3)^Jl7iIdfs{rk0-%s;5*&3qJW8SNr!Q1c4t3r zJx0wS76lCPnG=~o)hdkx!wZ|&14|`2^Lke4gaNfpT#K&J3su_xl^4{yb>|q@cPbQg z9N{M(d-_MlOUO%}*W%}hl33+NUhLM(k$YxTgmxSTc`6M(FDB`_!zO)Y$b3rjw{# zH`EJHPp|mpiq@zgF_8HdM$$oAm8(M7Mderqw5sPDQ3{o(C5VvishQm@h-kA1*-ogq zqvdr2JNJdVVwL0_yKZvpd8H3T%@yyBbInq^M=yeDDmA0iR$m|hgusH}8}{Q4#qDml ze4WXu`e4j%I2j+RZ5^r)hu~*nOGr7%*y(EP*GZf-FEHg>01*B1@8uPxS3UWy4?5kH z5y)QZ%SDzJBMjfBj6)1}J!pC4oy!gR-7{@>pTXKzVc$pH*0yWzi^q`YWfN*bPpQT? zXBwB@I4S`|XFU(=g`K`H&$(MRE4DF))y~Pk-!-o`;dcohZ{?i0^9Ro!eap(cHheb7 zkadQ#8lpYN=B^rjrDC#gS8}aFSb%oo|Jh5=y=&|TO!SQj-mc2sELr`#8VLY^lr}ai z2ZqH+6Q*_k2nrJ|-J&$72CS2BTBm?*^b(oTLa_iKQOP95qhDK~wac3>6}xdXLB)S+yYCy1JV!E(3Vwnbz@r+H z4+RA=5s0AfH41#Kk~;L~rVr0@tCu1Ia#s8>dv16Nn)b>A>dnMXF0#8n^i7{eF3z=m zeU{~}tP;u7p05UDKep@zynr%>_%n}lA*mhN-B(<7bx<$9j6Tc(7Z3hUWAL({T^2gv z2&f~EVrk{ouQ<@xC;?GBdc0M*ezk{aWB2gjnlhVTH z9;t8sC4jj0wW&4R0;i+1@{#X$G=fu%h)IN5ILpX~v=fy8TTBemvHLzg;RXqsZsGvC zE=laU3IRkE0ufk~SIS`~YCKym$MN*&6{;S)x{V5v2#}BnkqA?sd_Wc?LJ%@t;{=X0 zOmM&D^P6mq5mLRB&kWB33@L|&{whZxqLfvb2?%71KFz)b$R>RXfpL%NvVo(bb^;`8=tE9c42n*x;LK4J9%p@Gm9WnI@t6?JsbuRgjoac$Ei!;2c_)rw?-Sv=rOS$rNT)QXG2 z2uS}CDI!I5+E-w4VSn30$~Beq;Gv?ETi#-_#y|6OOTZ93{zz$P<5jgd*G}0U&py2t zM<4Hme8D(dv1gqI_?UK{jw;Fyx}`v%J@38Y9^bI_f|_=*FVbmZ zfe`heX#oM$y5+bU5v1j~(qiS<`9llKN7oT)f|OB97!W|SWk5yHr9^emL>^cZ$BU_B znAQ?zyxM6UJ_Dee?zm{vARv-95AW69wcmd^#7>@9POgo}+4Hx?-}t6nKI)(Ifdvs@ zu_q#=E2oFTTd^!aC|(#wGVxHea_5EU2z0B>#C`Oo#_$cU(ZFsO(+&#~Yn#ORmf9T! zYpK*`v>?3HbMcky5no<<;lXcp#8U6zZ*^GDHba1y(Z5--EVq>-B{sVBZH`H4O9SiDZvH)0-wMt?28!ygnx^b-Zx>*1W(t$5O0BbkD z>1e+62hZluyH0GFe?`iw*ID#pu9BqQItc)Qh(T>0ylV4Iosf&iGCQ?~VV!;TM@oj6gafFqJ|3qhv8vMn1@vI3G+BNUR8Up$yU{|z!; z@=UuKjl+Qzq4Q*Z?%yoyqn79SB7#KNO>+#8v&bc8Q8+8Wf`ArW>+s}+m)l$}c0B1z z9N6{2quo2-5)VXvE9WDE2s0{26gvxF|9o-aH`}YzajM?gUi}MNTB7OJeonASdbWL* zN03${%VI~S(~n=oB424pz3az?yPn)nGkwdhbAkQ%<&Jcdm!4~!cSYac)-o`?cYdZY zv>$j}UHi_J9Y>j$+0z^Tu^kJQ+|O=44+J1m6JMIyxZ!7#F#lchi^)mCC)N0Sd`g6f z&(Z*fNJv3)=un@xf`5#_F#AT`f?nFY~n&@ z@RC^Tn(p(vCY-sKhJjEl5en;`_x&1Qm;j=H$T`D9*y|sX{kr(QGml$*;os(UQ&HVV z_c=l5jT=iUdAZs)GBy`I=;k+#Rpw7ltzH&V#x8T_`AW%&iaEa+hn|RQ{WDjbCx`bY z-r9?cT@gQEY#o|f)z{U<@79uNdnfl-IA~iU^K#Kr6bp`O%}+&rE7nEs%W1yXBu*I+ zzd2dB{Ib$yU-$l%lQ&;Uu4vQi6Y+Xo=UhT*VjXD$=#{Cbm=Oyj2nR`{b|9u?Ys)Lo zH=h1LaP2DQupT=sXbmff*i}|2vdT%!)DE02s>Jq0))KK@tKIihBtw1Sl-r z1rT&x=VZ}KW|L}v^!X>%)nf~}z#0}#l73+eiom3mnl_Gi|Ib=`aOQ>WV=IZE(=!<(++vdfu8w-+h0Xv|Jo8sUkVSCd<->iiJkNSZbay|E41Jfp zA=W^NN!l`pP^VRSOA{;cRFVMjk!AELbj$CXAENnr z`p*?pov$wN(0n0GRMW3ETfQ}(*KE$-zxRS|>z{kFTbg~>?IR~P9L;T*;^M+kGM6jv zTT?4EDy?#(*eGHv(RzMh|BJ1KABszU!)y82j zb*e382TsI&(?|cnEcA7Q$}WS(qUva^Ntg|~04X_<0t8 z(kYsT>N-JabGd!vUpiX&zKa5|5}Sdqf2h;BbHH02w5QgsY`cryb}gjcQ7l_P5NKD~ zZ+v2xduwqaa&V?v`}Ak!*w!Duv_9N*tb0t){o?Fkd+ef>Q^n22j%weFVUj%6`49b5 z#X`MmUGZ9hggJ_c7*!}aXsKYKn7C}+nl7~pXZn8Di{b)DIV&J+fZW*7DZ8^%^9Noz;glUpG$j#!xINYHZn!BB>qFwRWgjC@ ziIXWN-CVZk{+UkxrtyhSMxCu;kem)>FXkb*Y@hwZAodAdPW%|$1VlWe+8&UCMj*RK z+i%AaP9kmV)#o3tFL_^Ie!lae=lvhL$wi;L)EmG#A&*yz z|G$#-7CvVnknRrWy2xpXKl&Ft7-yClB_W*q%p6fGv`)d->@pwQf8=PvmKRKS$gRTZ$Xtj!tg0}0$jY0sRv~Cp~D0TtDGq)Dog&an)?vNrICu|3; z)crEdxGp9{suvE8i4?HqUOcV0z`xD;@XDfW?6x7^_dw(nN3TRT>7phDUMz@8r9~rv z5D&xAd*}0S@Y;>JpiJj1-XRF)>J^P-V?iS5paLUI7%00~T0j$A7;~ ztE={8|BE}4!uNGFArgJpi;+&j`vCy5Jb+Uy02IqeEY___L0##U;`!(=kJMpUHR=O> zJda=d^uhmdWqpPACSC7OJ~Mn=>q>d|H+{3|)-6*eXs(>z+BkOXPdEMes;EmNTG4!~q`Y@AR-{?OCNbIFJ14lJ`(=qN0H zdRAS(Y{sXg*ueT{v@so^t!+QZ&4iqyJC&b!s?=&JYTLN~*a!At{rY!Yt|Td*42tu)k&Z!jDAXme zQO-AAsgvU+;J!?}TtlcT;&x_^?B9QMA=prP-MkEf>ghPbz$S3xXU;3`x^oc_QTk`5O>eExiF{QQr zqo{}hl>#ntBFeF16l&jC2)fTjzc~&LIMN_E@_8$_6+P!{1^FG{7a8S?IETM{u(J00 zfu$FJ=6JE|I?;V6^(`y>(0d*9s%1?A%8q~Ppj`Ow@m6l;KXh`x*vx4Hgnx!=L_s;V zs!DU(|6&PT5@rHn-+Y}~>n?DPkKyq>&Ca@Q!^Npu*`JyHWV`>#Yt|o9mX7RtQH}KY+=0eSAuUl<5AKf=#O1Z<;?#`FK=^;!w)qKD0ggUWp z7bc&bE)^>6#7TB_<34E}Tqx8I?-@EcYlEttEQFJfj!n&Pzu?-=f%d8)S1XBgS+xs- zMG#PkvY5x4GG=4eGmr%paO*&MVt5Q3Boq(?w}xL=)k`Bix^V1#clPMYS`fDgRe5&q z(8(dS7~eujHHixXEG)&uwNEB~r9Se#$&}mM=?T?ngSqh%q_@%xvO-2$3{(Jcl@P4}}BU;<`&L zjvE;pt>2*OmAjq4YUz67fWhSOVe;QF(6A){csYfIpgJ`9RAa@(93>8$?-YwLRymM} zYA^gnJGgXc#+&_et+008P(w(A?VTyvgosvvAthUYrF54p|u?;Dduxsa>=)Vzl|hO{@mkV*QM(3uGurQ&)>9640*8BOOzQie>VmHvNnTR^?Z!X3T?Q@o6YAgsP{)m z0K!19|7ERhht{^B$X*l;I4W1DN%zuk?UDul*y#2N$2XoYF_XW?$`nXp=2S4+OW zr+n>J=moRRTDV5bZLz1i#zAH#hvI}0>v`XWM8&Ra`R9l0wwj0)-|pEJ_1};q72DUH zsYhumCU)ljv{PJjhpJ$l=8&ItoJ2uu?vdF`FWi)nK1UJstGX@$n*?O$3qjx8Z{>#l zi)VB1z1?X9DOKh_I~{>e`8tV5ie`8pcb)F6@u>P-Wil&iuo4&K>!qf~NMv{$zD@cmNLl`v1&3 z5JWrgaohbD_vxvV&m0SkUH|T_rB5FV{X%ZoSrF82tETpUa3aF$Pd1uf-;R&o*I9jK zzd!NUPcIbXVZPm-4aUttU-HOjo@-lz>gd`r|Ko@IoKmTSPFqMxD=zl5cv>TLiVOED z!LU~n3p#Ok?h`)m$A?Nb>^iQ@M}E=M3S#3}Hv<8khyI}x+_A0s{6p-f{e9Pa8Z6T54FWh| zv~iu2b-y}3G_TRmMtN4AMdGWcJ{!LIl0xVM=cW<{722%>0{O1H%}~B8)w!tq>51jjw-=kSaXOtxYpS?r+4w{}`9igN zbkj|nm-V}G({tOM3fLGODq3x!6R;Bo&}uUoB`{X=r)I8*`zrUGjIp{-S$Fq|%A42e zcF>8ePCb7~3%7t7Kp`R^LE4}o>%FJ60HKYZZRfY#S}V3RfG`8*W_LuFpFg|fp`Dl< zK5?ku-K3!F6!NqA{z$~0!Va7!WhYslXF$iF&V|mv!u(Oc|0vGsl4{N$ z+GA%2h8pNUv+#mflJZ5Dp;37}W+^<5SrI5j+@Q;Hf{tjZ8Xuz>P$!`H4^s7`(2-kL7QH%1-V_hh5tZT-1F zvc*wj|8Q5L#;(0`EO_+52a2VI;fu|H6B#i$^NW`DU-ZtBrncYQ@nzLse>+V`^j`;N zDqDZzVkcL3-F5!MEQ+34YZGZY<|4) z{uQ;)K6S)<&)R%j0!Z120T`G_fiN?XS|q-fULlVl3$t+8IWb_~Ip$(6!L%EBd(S+) zZSb+boXyBlU!t%1qLV}CTJyOA(0 z*k``7yA>STJhJbPW*VzP>%@m6I=g8pAN#`R=5s6b`v-<9C9gT#UGc_u4eDHb-jlg??K1?&(&Q<_UJ*Sal-x_c+o}Bvc$EuB1P`qf}P}~X#DHqL; zLTw--G>!RV)2lW{PL~Cf?{wjvDaZ)YIr3$fF1yvZEj-uxTjkn`=SrJ~f=*7wjygM4 z@KuT+0Yv*Bb;3LP=MVhlk-@jzs%HuwC4flq-E-xZu78FFK%{_K(24*e5jnFXGm%c7 z^9SFd;OPBdtuL!38@6|&{lB_;VQ~FV&sa9q+5b_GD$9RlRMnsR3+a#IWb1pY5Taw7 z^$*W3pMKLjf}~N7pZn{>368wK-0W0lxLY22^Y(JfZweO`=Ao%#hs{vC`g8qmSr4}v zre-gkc>9P~*>!haRla}4(Nz!aU)lcIb%#IvLU-k5Z=#NJ+DgPtSc5?)Az~w>bjrIa z9DH|{nftnhtg`F>SO2GwVPNU#Qt^WBOoos*h=(YNpe`t@;OuqiX*!lkPyf7L7 zhfC1)lC;Q(kkZy;gd!+R+Un=)%ik?DEZx4SQt@KNaC^Qe;97v?u?oZI`^-Nk9+*sc8>Kl4}9oSWCz)&G-)fnS7Oe^Yjy zpjM502}sU-Nq212{{3g2q`2|tmdD~^TL$0lujp%ACjWu2KYPqw_Ta{nCkusXd+eJJ zh551n)YsL{d|%69k+3R}r&K!+?ZlQ)OUz0PI8Nj@N-;}dRpd2FP%9dblcWIi1vf$h zghsWUhls2w@uX3P7~*^aI{N5GT4T=T^`eu6_dj+|u^W~r`C|hc{|6gvM4eCHH&=Mu zJ5B(&W6O?xX!d`tuEgl}s*jec6WKG`woBb2Gk|U><({bg<6q@XeqsEk^`-HFb|8&1 zG@U{*R%LOc$idH@TsczEU1AI!ii0pJ*UNDh`>}32XjIH%S>l=C)+8D0)AjXx~p%Nf;7hL)!mAn0V{`mdfOIE*i zcy2(~az<$dmaQMQD~&3)3Xv1ryhMJh0L`Lh>oPbTYaRNQFmw|1Y9)hCCzn{O{Ie`1 z#YAjC+oi-d07%pPPfat>h$74=BpHWelCSsI8i(HanN7uFGI>|GdbD`i6-P$(Z+eBf zn=bUAqsHgHJUdD1?c-(v{-P;;Q*&~4Zf05Csp7%TL-ugr6{Fbo7Y^U`Trn(+i}aVA z$+p7H+th@%MM7XkWJw}{#JlMNm-3uve;RQA$iUzMcW;7vO%8_U8az>6{q&(ddgV{7 zDa6{dG1(nNSEKeIjgWJvA`1t`>!EA^6h$6DmKYhy91W zqkXqH)vxLJLRYDF4_%Jh6`Il2iHDzz3*$d^ek8&PisNh}yP(Tv?)W2*cdAhDJaVHo z&E)t1ML@d014j=gv2)y8@KhqU5H^p>MiT{XPKvVNRT`lv--DPP2chHSp*|SZ3XyF} z=vN?aI=6g1DXclMoAJoMd702_CxSrrKCuA6WR<$*7+pS8hvf z*5PKc|L|wO@?z217asTaU)48%&vZX9dM^{R6AxMy$+z+>N$gPpU^e>FC&n5UR^D-k z*BP-Y!K$6FN}~XhfbkMrd!fafJ*|-iaBgV-csVS0bw=1?L6HMCeyx^)--xlu_s(ov z=7W$DFrmi;7=ws4RI;D|SfFs+eq(KC@O1;{D2V&bF2BD0{JzJZo?5PMU4}V~MJcgH z04R1n$|X+F0mlchZj6kM1sg!eIGzKs;gAED!aO*wQoIm=bum7hK}dQ%_u$n1{%p|% z!puk{!T=zoleT9ce)itY*SjYk`^t26((#wQnLqNuO||$hhvrUvMTT$x#?y_VP2cY| z=$_9d&iX!+>#hn%pr~h~2k(vx;oEO0c$1C4-)jr3-hT6WNz``C;QqlQk8aL4$^u{$ z(T)rAT`X5uy>V3+a-o1f5%$HG7u#daPoHR>f9q%zee&Ua9-e;g_{p~ppYJnit=I|* zAp)_u3~8I(v|t-ej{*n~L>QR7ng{ppDy@B6pRp)RLAzkvGhf=j?26xv{2MO3>@VAa z4zKIap#@sX0HfaHN>ig5g{>gS!t}xJCYK=}MPb+`J2$%Lcb@X;;4y2f{_}-+lk@eD zyfodn_Ll-@{-0ht9ISrBwxLkE&tLw+hRT1RTNVE5rayRQ=;ru0pI;b`Jj^#wJYiOC zen(K;t;^TM>;q&dbfE??*na9akw#a(_ll?xYan2yLf?v8J9zll+mVph*Z0ZQ&e6M$ z$V&t5!c4O~64T^x|MBiKf7=ZHeY?R}*tGeJDS0y-O*91U~ZL*T4BPKikd(6rfe`+OffQ5u<~D z`^4DWH;*_8vIXt9AkfVND2TfSO9U(e-Q#zU20v!i>s199QM>Q(N1OU~eE0)Xm;c1_ zl^kneM~Q#{5gM?OmxL~A<6y%UzzI=Dfof6_EK8h2$^)G2t~}f>OyqHWT5oqYirhoP zQsHyve5Tk~&oK}Lc3b(G@UH#Cx1Gm-JHJrA^E1=%_FW%5mfQD}Tl$c9-UFAEo@h3% ze&=ZW=|6G()_X7c%0yW^qhLa>`WGMUSM3dV`p0L#cyv;$X6Mp(<`OP@rm%6=IXH90 zra}un2M`?(+@q%5b=?&bDJuyd{_>0QHE-!}z=>n==*<^2b!mE#P ztd(AKFhq7Oqe8Zr`TUp$12Dm?{#dmN5k6ScH(n96$R_NNr#kS!ea_5ZQMC34R-7O! z7~h^F!;n^I0tjSGKma`kj~E3>fn|)$9fNVc;d&mD(;j*D*Y@Y6FR&vU?~wPG>>KRY zr{coTZT82X{>-!9y35WxzvL!|V=+DWg; z18WwxP8a=+T-Rf*7w+1Bw9smFi^;r(di~3vNmTx*_LGKF;W@{5Jm{hnoF7p#cF?iH2S0HakC!&;eG>fno@4%re|Kx4 zx$Jo6B=w)gwO3TyzT1)V^3MY!xHHm%-7YJce&`Jrn{uUjpG7w5Y#Jz6^d zODAXKZEv}xDK#vBQbrI7*f4vJA>H;pYKuy=1Ax|P2T+WeTPjdp2V^1d^?&f8Np(nV zxpV7zPNIO$DiTTQ7^qi?bI$QB(SkFx0AS22Uq5gxS=b`aAFL_u7$=AhP4vmtm#j!) z+3$5%eyJPhR+jYC7poJ4m)|(@u_PRFi5n{)eDvOG3yYhos&N0K=uJDNzN>s?wGZvw zMF;o0@FjE2f@l{LCRpg7&x0&qRZ6Tv2VC5DNG=+xQ~PkWUOs=YMYHpY-C|?eEm!!( zW|vEBQi4lLt6Z#Q1hUrSlyy@Hq+DESlPtl|I`+uPO5-w}h=XV-_Twb?&7>0TykO-U zSJu)WFKn*PhM1&WJnui}c>~>V%jpuQ^cD&VB=EnGS zTmRJ1WpyYvLw)#&nG=|pse1Y26E8dwR+|&MY^W+!+kc>a(@<0TVaLp;k1oH&(15Ai zMsYZ9<|8Uheqq<6xtZ=sxiKycb|#8dayWQsNBhnVj=-KqQ%A!h{muDDZMgjM%XH|$ ziAM**QgYI5G&Ck;1*DYgXaxYU^Z(HT$@q>CD2}h!!aPWa+g~pb* zj$y*2TrI)?E*yJgesE;i1t_=m+`h^R*MzWw9uVU%tUi z`ro%ggmp>{L}_&`ps4m>+uqN)i^d=SW90X~@$V9E_LnMcz4FH~-_!$lmCHwVo9ePF z%J+WZ*`dvsz8g=%s`^0RoZ|a`w`)5l|HIo~eDcT>gH-OeP4S}DA`8bOC|)rc?4Ex* zZe8iQO2C>ZZtv*mkvlJQW8Xn7PN0XDT|o2w=IKtDD<<{ai_!3v@PA)=sytf;{;g4D}AK?J8BEa)#q;g;1lLP18`abjrBsp-*+;JSCIu+#>3 z-Mfk?sMKP?oB;GFfm3G!ltUIE>BuKq>lbAFh-HpW9C+zLy_W~N<;r`mOG;+Gc-xBP zqsNk=9}G@@{C*6^-}atvS0yejJR6@V9yUL5%i!qoCLFI1=-}3CwtVEts+#V*^4;TL zzy1w-&D#b|l;k0E*Mb^0zwyM0WjDXQI5(oZMFwNYLF*>RAN$4=t^Vq=`9rytQJsXn zTW_N^ov7Vdz1bGzng6{@4|H&{>a_Bmg}-@Ww$+XG*r*9h{_NN8Jrb*HMmi^Ve&d^y z<@$`b;veMP z%d|p#jbhxnFFvjtk1h;c`E#SXQKB5$*d>J`0E$q(_|(*@cW91Xr2%RWJy3Y>SiwP@ z8iJ`~2kW_aIfjp(_=o56&A7gzTR(LC>ScgRk~3=*M1Wb+wp1@y4cZgUlxERD8n#75 zT|fA^_~-xT+}yH@-Z3@p&N-F#+Fu*?Ds|}IV~*YOd@w(zqEvf-4E zNg(Uu1wbekh|TP->gEVrW_af}_8eJIb51-ubjgpFtzCC;VD%f{nuPU|EzkYYUNfxk zd`n^f1+#(gLF=g3&2_iFV`bxrdahGd(WW2S;IwwmpMdh)->_;Po4dN|_N|_YTy3(9 z!$|FT^mEO%zj#H6t7}f7;|UQ-*xvofV-qbsmV4`)qE9_B*EHt+7x$O?TMj-#qjxNi z=Ncb8#Q4UGzx?USnrgCZ--8P?w_X0|eXHJV-RQ`}Uwy8&yvz2n$DTf@P2J9yKKbwo zsO4%97cukNJWj>@9ixN&PR>f+@tw;Az1#Y zWx+(V;zrGPt{?h~z0LJ+ersX4T^jGs$NKoN{o%EJ%4!QB$ysEd{=w(D^2S_OAv#At zbGY8;{HH5p;GufAzhILR0X27h?5pwU%|omLQBn8kA$!$mpqK$f_}JE}Y3x`Di{;ww z!;e0Qn5$W7MfaZ*ZFzqpk*m0Y@ zcdk+@y)-?z?dnkB#8`d7d)vrBz3DqAhnlczcen!9|IR{l!>V@MSp z;rx~;$8)2;2L}!mj<-j3d&oUDV~T;^7tHMY)Fp0NiA9CXLI8;(u|mpV$|8&+A}ru= z%D>mn*MpowyXwk(@r9@bSR1X74?TE#`CBV!E zk?1(0~@a-f=c(tsk&V|?p3R#fXU`X( zpQkOCG@h$$AC3EFANbNlKJb@U_IzefQVmYVtK`Lpo6Vw_*@V63;k%RAw)*E+mEN$q zcJi@fVPRQiV&~)HJNbda#8a&+U#fX9H9j{y$%Au;M^F0JwPD{i@0pp~QCPcqyPs>i z@kyGTIpNhV87d_z@gj%4vr^g;XZmDOE~zp|$C=-~zj*Y=w@i(gLMvZ}y2^9XDVRsU zX$raZZp&3&<=1Kp;TF$JdmpnM3GLKKUpto;Iv{js9{QW*KXl$}j+?i5{^=F4H6JlL zz@QH9eeq-=$EG?z3kr+^0YnZP&nS2B#2|IfSuXf0r5*0*Kx=5_WmjC*v{Y*4663mI z5?D0N69*2?EPr7}j?OUg}HK;ratd2PO+6i7##IVH_Muq?Tr^rJ>_ivnYHVR3=UdHFu)K97!?A#h7}4>j9OTwO-KO&yvYCXmxB2FRDbM) zG7{9Cq-%1a6BWL1Wt5Ndv5kIn2>mz>`-xX~+!@R}sVYU#%!pvEcALeHRCe7r-1*dJ z+B29>5=T)ua^v|I4Y122Y>@NUS{}(w_SfdC_NOjd<2liITxj{WrBCKy?ERZy-Bb`= z{TG8*{y@2cN2NB=H}-Ef#iHl`vco^>thIiK`A~y1-^ZQ&OJ9)bpS)loH#H0)li$^$ zJ^Gc$ro*t_D6IX}fgr!II$yi5|Aw)Gqp_u1Pxc2RBhUZ-oO{P@{q6&;+{N|wpFO%U z_wQGhe#JyrLgoHrA8U;5$hrPKe|g+lefi|c+?9X5zn0U3C)KuP&Bo(@|AKZs%Y(!3 z-BA4e-%Y4!=(5MYd~d$xJ2w@c`}Owd0iZc=Ew2uaY2PokO7O<%3vMmy@WoFR;YYWw z2sCxw=GUHn&Vl~KBs&#(vZ*s&Mq0MmOC+AJv>K0zD1|XwZ zRIG!k|G97A7dOW`m>%kA`}q$Y9>LD-*WCS#tv0LmpyN-JsG_Go*}drv@rZ#Oii9~# z7zmZ1M9Ky*U(#%JtQ^}Y8BJnE5w=udIwiwIz^9;Z6b3jw1T#tVGgHmAS6z7F>W(hp zd>*Xp#>$^B(9GXGnABKtW0|{P`usTd^Pjl84Wl}Oz-dlA^F>Yh|1#Ws_Q5ZN%?-7ZJ~ja(t0XCkA39kA7-6f>2E~o!nR`0*ZCh`=_vs79uBo}_3YY4vS0#BVa0sE} zKVFZTgHy{df5&-E9s#rFinmt9Dis83yciuHb|Th#;HOrst#VAtl;T3-fjR%Jo5qTB z*X%oX`E{2J2>b3&PM-MsSAKiw(T|9i^e0M(lX~{b8%B&F5@9>r?2H6Amm7IsCFlwi zSqQB)_J)&(`h4}qRhu>({)ZMwl85{zj zik&AH8t=Q4Nf3dsg{~9MP9^QDwq2b!O6&#yu6u2+6)F@!1niVsAKYpF{m3HWgMa`k zOgu6I^*s+QxQ!%mX};{m)yuBEyyUy!1}=pn!eZ_9-#9j#EPKoL>PV1)70sbQcP4hm z-?iDdFwDu?pI*IcLpSWK_=EaT?zdMIraW)gbhA=X37{5(;@UhcxX1r6=C#)rOWl&0 zm*L!8$3FDH^g;u?x$~}FbE^&N$%-AnQ(NJ^V*@j@H*uFyck&nh)HApKU@6&q&!Ke} zeB+se*B<+ypV=AtceR|gS6=aNHq5+ZwlVc* z9gLbQob!KtdGOf74;(H`t@`=LJ~C4;;IS?5{HGHzp zc1^_YuKoL0U3GQN%eU%s;%g5~Ufgn2Tu9EUmNPx(iTd-lcE?H&6W;U1x@&*_g7qpi zfL<3uQvh~%J^YW&&d_gF1{@Ye47E8(%0!Zuc!4Ru{UgwR=$dN4iI)#+ztWKK*ESVi zb$CvgL0D1qo6iQak&E`kfNk@Wf6^UOKYO7!xi?s~!nKHW;$zc8mj_SZTe^CyZ3B9t zDdrpKyJ*@cYc5>~QPb&z1lm9INMtB;N1sf+;+#7 zT3ff{vhu=QJ1*4f!}eE>H0CDD6s~NxhDz2qv%{??KitXZw(0gjT%Y;HZaaCyuQgUr z{8l|~4ZMHc3+6ud^@uLIVw?bCchM7C=%XVmVea7H?(cf=+@41sdARY7Hw~_K&-I*i zwq7#3S7cswpRGnsqprSnc_dyl$X{^vc*_pYgQsFxacke)`r`{!TyaeT6fnP)Fg|)ltw%QetsH%!Tt$p4k7hi*aR*fB=LBAO_v) zi1sD67D}#Nar5=n?ud;UY$wT!007d=&G~2OrW^bEbtFo_phN3Tij4h4|V+f zog(=-c1ebk1xF@2M~=6WP&!U) z`&PGNJeYpEJ-h7cQQLBq@{fJqhs*Ze7tD=q9h>evx?^NR`<~hHy>Gd==03jAxS$(7 zu-g>I_r&q>pa$EwZ~2=#78h3f7hcw`93C7>{Ajk|7UpF52m76;@BijpLHcOqONaW# ztI2uetv!2^;auc3FWUnQ`KCC)Z zxwA56EGZPRY!AHI~OCVw?*)Zcdk1QVxVsv2t3&-2?W19z>&Z{TrV9lbFB`Ahiz4~T{mcV16ONXt2*^Ro_W*ioU0hxQpiEKEJiI< zkP`w2TV7BXh=!sw0&}h&04NVU`QhCvzy2+2J5FaXjJqm|ynGZ@f;|VCwdco%i}UIv z1#K9Zt!m-D4ef+Vmy^K#TrMdL-6+Ge4|GlU!tv$9og+u?-ZwV3=^Cf)uu!k#ob*kO zj7K&1Q~Rd$TsT^sJ{XqccHTLa6zX&ZkYe2g2GXLk74~?+Gz0bX*N$_PFW}#*0!1;C zwpy9cajo}LrIxO&ozXx!<+}5op`V6n9^vyAtY(_Ty5RihS`;X!a>b1oYU}2L1+R>T(H8+EP%_lF zEppr--;Rvy#RkCXK(hMhSL6ofita3sBO`8>ws_4ORn9u*{#|aNJLo24na}`e zvh>L3kMX7N%x%cc4#ab+L)Io*lA=l|%)=-Y1Za~?Xm(0y{nS)iOwlQj9r zS7T=*=8%#XKjrBgZdl`bT1>))nB&~ZPu=Cr=f;jDKiAJj5s@+Rq;)F_LO&~YOw6V$ zXQD%oUqr23c<3XKNMYnh$X4K0=K+uqBF-P!v26XVTGgVlFnH;UyJhqzx2_65abN%U zUAZ!H602r9PUGpO`0Wq0@|#v%a_M}ZV|?44t8{b{xIGzuw$>WxuqikAL#`YV#-EocH7(L3{bXU43HbLl1*@`P>CR zU+)?pT!`E(+}}n>>YqA%zlUpAhJDYkyDyRabfKpFbM#i!X}Ft6()1abb-R~5+jraA zw}h=Sw(PQ5Xp_6bM#_-RY2}K0X?|L4)CY(lw%s4mGEt@+*rxVd!UZHuVsttq|8*e= z!qL_*=l0x{ER%s%mFl|gbjd*x0CKDAw`|<%Mg^zCu|cx-KfM{k~8 zyU*AeZS z0N8;jdv4;a8qC#hS=}c~=M@n_G)mdXyXbeHoW=ZLX#h(kIgDBrMTW(5(x#;Vtn;NH zj4dkd5gThMiFDr$C1^S>=RrXf5P)JkGWN3rK7>@OKUA|L7nU3#YkX?wp#+`Ahvw+o z_vEMh4vkGMxDS40!T9RhH81W?a;vU???_YKHkG0NJ00_XMkSF$WxTWeu0FNA7 zH+0Lk8nc7osL5vI3p=Mx7`CRgv+)NzNE!e%K`j>-S~-J~RXh75Z~D|UJ$H1u4h}zj z|5X0pZzvAbpjCRc3P2|S6z{z|TzkieYu*0D?tK4TG5NjW+|Il2865e=egy!a15P3@ zZaXVxY+?7KpNj|n!zGT6WuWob_tNl>-Z4390SJ%${-N4XpX#tRnzUp0@A|Wo0hLW*qbH~!boy&Y%p-*2lPQwI)x{&44(t-hvOvGC|0Pr7U7zV9a|elNG+yYyc7 z;F=~s^R=Un^)9)?PFKJ3rOC?Zzs>c{-1|=eo9>w0^WgqYxuveza>-ng92IE@-Lw@H z=JU0gyFSt?dDq>3p}`-TBnk75lf&p-7C>6GjcBva`?>PAVtiX)ijBp_XDk~@d<;>< zN9{LSwW#WnvI!utMW@@svSSiE7zX5rlqAMt;5D!{FdbDIRDtn#^!rEGV9=cl#%#(#k2CQfO5fFR}HcfC%xWNSXNH*bBHR$4NN@-V^JvLsUUyA>HfS(IFl7O z^cd9zA-2GkmT%pxNsec`m5py|Ux}8HOzfB_un`zYeILL&6yJ9E)}?rfA60Y_t^yE24v>W{Pg4G&2I}X2~sz@N9o`Bw45!6IcN9Ehr7k zps5m7>dZN-uf9G&auSQ=>d1|tKG7~;HQ3GD71P{#?8`?g1JR}}o$01t{iEBCj5Vfb zc0S_O*WX?}Io4(eY;M1_x4q0a3C!FM-i6dK{v-!SNy6-sE z{NZy8UwI1Wdg{VQDa$*-s|Tz=!AZq+_|^hE#4((^{|Tl37%ZfMu$>^;GG z7u>)FEW_{^1S4)uFJM4fK3_weyYqt?LuwaJr-bUJw=VGCNjZ8!F;=U7M^ zzxeUOy78lQeCK$f`S3$0$D+%|2rAI`sxbfs0N8x;#o~2qs*RE-Z8cLX&EDRpkNwk0 z(Qhc(76T+4h*<6G+#C*6Tcf3eNzrFUr+egUrB;4hznyPXeSiA+&Z)r*A{7UU6Y`S1 zU;f5IK()4B8>DS5mU3_$s5HO+m9RQCziG|94JN)EzL@A0g`^I-wm4{`9^(?(@vrFxHcgJzqNirfR!QUu<_BQ~yl+WKb*QZ-2Wt5HhMT3Bb>; zBTW_Ji!F9;y?VV>t@#s`uu$H3d|8w~b-A3X0J6k)GLWWNHroEie!arfsxDBUo_Ac* z^8m#dC^R_|2@pwfD`%M;jUo~UCI~`>L8oghX!1QJQ5Xa@C$K6I&lZM6<*?&9KESad zM9mB^k+1eC9mHsu001o@0T4z-+p?0N2QAW&Tku<6Dd}|40n09*Jh12b%{oSvmHp7- z;$H#))A+sPcg@JhSg+)$E(k6%r zd9mTbU=n)GR?#i!AcD%maHWE60O1m(2Xu+{Ar%&4Rrwc=MHI6jNdz`{9{q z-LpF9q%Km_xVs(c6$FJ=XK!<^Jy)O9(bTbtV>4*t*4lOdrW$5-o?hyQZw#Lyp=9W4%GT#KZn(OG}_#?gYyrM{EVHZW1R@a3_8Mc+vK#97#V_1@OczSq-#Frr_lW9x5;8ybp>QbbjJ~Hhd*Z+so+vj;l?$g@ zV-fE7^H}w7{*Ay+zc3LE6~6DnxU_t3;(yOAyX~Gh+|cmmzVI}ux33zUTmHuf7V^Q4 z+0GCKm;KQC5l-4}XtYsa0wW-Rg0}c)4>oo2*RQaVXFfSP`mCznc2gygbEO08Y}uqu z@kAsr{jC9LC6%aEz-ADz6NOamG#u@VRmx*+#fJG}%p!)_N?Qiw4mdD3;^ewX(JG0I zQ69$z%4*JWZBovkkZL3ECEm2RB0^1U7>628vv_!&uOBF{C~5?Y*z4PoWdZ>JK*o=W0I?ItZlPO9GMJ!QtU|X^j1%Ur zb}WDucA)ELnS-dxv4)P<>ZRl+r~nC2A&0rVf{}@N-=Tc}f)7!Lv;)8%QF^~0uDVt43c~1 zaIVzfoC^Qp*j_L4FS_~`a36ZU*U zCX`wDlxZYfcDyU_(81O~9e=uN9MVy)vVFX@rkP)VLuW8(?HN z8VePtwf#yDZ4uf9lKlCy?yM33P>KbmQ!8~;>3TP0Kgi9meBpzI6WE&Rpb3+{(y9%# zKHq5$p{y?DnL#VIi80;6%*mJH+Y~B9Ce~iNuhtqoVw$`67e|8Y#chp7xmVpui$N?) z`=CEm@@5BS`a1aOg<|}c?!jlD7}|VmzY>V!I9Ik|=xa;k*N={OY=P7=x7+EDv3=iE z2-lAsTU`&NDUC;_R_AgnBa+GN=W-5lyKQ~{gZrHjoQdpS!Z21YVWG|_@4JGHo02GK& z+8`8w2yMQD2%_8Iiw)R1u{M6Fk`7BC-2m?_Ct;#{7Dxni8WekF#L#(0D z74abkYH*$;M+tR{+|inBlmJBvRXh}m0!EFT03bPHS=hJ=BF~1tft+z#${~ysJwL{? z{6`OoK@;b&>t%iJnSsQnDg2qFHoarW2k{rXPX$I9&w%5AbtJ*A&RGSC%PQp*1fzf} zoVYrRrY!Q$=j0p%p1T+}*I0jcpVstP8FM$9wZ~+p? zOG((G2`nZ~fed z^7bvu|4Lur#htIY&ux9fXnO@b(#9y1wX1J=;V;apV++f2r7P~N^1xKTRzJLbQ(i?s z`$=`nO`}+lO&=~4bz7vaO^GD(t&>|H0qhm=jb07KnuMdhRF5t8Q9h zLr=$yoQ%Ee4=UXUcicyG^H1jx-9#DGp+Dc37#9-OI&CYN&I0R)e&kS?1VEia*vm}x z;vCG@{#uzIwr;*zOg{;3BhM=yTcWw9&H)OfRX|Gv*~JiRIZ%egA#5D`45 zKE2i-XKb|xc`rLHOADBUOOhGO4qlrNYDxZ0BHZ_ zCBNN2ZQmR}F#p+*EXETbJeF#h<5 zXA3hwer11Tz+}I3yWd)wm<W+rD(97x8;s%>BuKL6Q=NBqB8;F&EyeGyq~xvla-`+vW5bZl=lOm7~nM5&nx z-8?Lqzc?7$AG+8{24O)R{J=fKM>fc&L61QZ(6b+Z%u6=D>H3v!$MV_Ty+i+IecGrfH4d8?@Ydyi{^?x6!-NKxklPH|{?ar9|Otn^; zE;V+4wT`OvbKCN*imBz#@p|+rB`_dqvt0m0P(Vb26cZt3wKtHtbR^7(Y090-Xa`a@ z{$4_HRt=UqtBYj^tbq1s zqmip_*r@WcBe%BKa}=Nt%O^H$>c;&6j}1>>x9gcDx!tuZ>Tt) zp}g2&En^ar=hDrm7qR1<1w2AQrASGsbc`>dlHkz)fW!o=-#DxNE4&8&b z`ud?QYl`_Cj{(G%Py)=CO$2^#p-$(n>Jj&3-+6ED>ttOa zB%c_Gy1M1YN#f>Cq4rC1;bB)hEzboLr9lK^PkwcP%6G*jh?o9NpyGc{OUzl_BFyjH|j#A`Rq3vKevg6 zQG=T^TudJ5Mw1tBUG`!h6dDI7M(V+qt$m#U>^XT;v@BM^5g=#DbckTZ0xAtD(XK!* z;UW>Sf`BTUzUH*0pJwEUh|p<$>?Eb{00yg-1I99o2q0l11gMF6ES#RJ1Qe$!;0V$t zr4$;L!vFYq=)CEs0q%l2CA<067ipkiMva)=IRY#-M^c#u@G|vmr(KQG+x_<0S~HO5 z)bj}}*$^&(Rr*$VVKSnXS6*^adyK8;giWyLJ)d;uimLyn_bzjTy3Z^ELcTKb;gfOk zjh9!g8?d5DHI8>)RwG~$i};7l!4O@c-!B*wB_sc>p!|VOs8;{* zjir7S8LcDNh1_Z{wxBe}ENunsJUcdgz zQD*SMD1X%!QH_p25LlMT7*P&@N;oQX^HHZaQ-?Qh%e`flhJ?}4Y!OVLgK|zk=;X(K zRj9i7ek9B(7abqSQDPSxEX0obtAA40|H86~KA7Ec!(z1116la;BX)c{9drGYvkOnX zM4FqW&A$@o7sNzU&m4NZtuK1Rs?X|Nazojv6|=j`*DU^H5}38?&kbty{`Y6-)e7L_E;-F z*7@7#yU{RJb=0lhd+%(}|NY;ex7>)Q`0HP+m9`IEIwa@sjhwdii@)p;bxK;MXnxN# z&mM_~rZy-Qv+m?n;!MAAuLN)U!Eq_XYTp+>ay*}G{mB*GTt1X@f+f1Q{w+f0&Sp!Q zZLD2QZk7^@0wMxJ)&dbJ4L#1GtgrUck1QfYJ)J;EEP!1kQUKy9U;!j63_u8E4Ji_% zGHT+0gaJr|1rUKrqe&o%8ZSPka+ka#pI6G!R?hCE@oEIKu6QhqOevJ!jJT<>d)ACL zh?KP?=_MY$2IX_g5b$&tKoam3ZXeCffID1CFt4cqfNr<&)}w1K+gbmGk!3kSn*&5r z4AgKPuG0mFdGQ+>&)>AQoeQGS}@~7MaU2=Dfm+A6cD8 zm)pJ{MZT*vFJt0JyRIS9U;xqc5K%8alGY^~q}0K{GAT${TdixA)(0QVEqK2>lB)$t z_Krf*6bsoF_%5_*u_(rGef8trLe(tv9s5gossFOY@r2ZUVS@I);Mb_ojgc3sadgE^ zm*E(2<8$9QSu&kBuf7ImUUK5$>mB8!ewzRgO+E(tmOID6(FdOX*#45qt@D3jRf7}) zklQ{q*j6_$Zz8#xljN)}&JITd1D*KjPU4*$q)wZs2VfDbxOx3X;c zpBEqfugT{swf>*pIuN+?UaKSW#Ly7dBNbYhSVUVuS> z&)TUMzlL~P1CG-Vc|~QA>DAMQaSDnctaTDOD?`_hF%12D$7r<5DY~~cM^~)stX_?+ zypuSVK?$=|lyr*D$eLk_B~GkY0+aI_GX-JN0D_Jah1#(lpPO!yi^u~2sJl|d4p5Y{ zgt3=*Lg>V?{8t{Q5%i7}&P7)dr|(Sl`AVc;0D zBOnU|1hA37f?ya95Fq|x!9emS35>vrVjzMokWLghv1Es`DQmH0i8fb}Lk`L5;Viw? z`rf_gf3?YQUTn1LXA_+H&U*irhmO$N>}su~Za79SZ+4m(bfS2a5|4i4IgN{> z#4msQt66er_=Dye^zJi_<98j6TlT5={Kt0p zoh>gLg`WAJ%}(;rzE^iQp8ewAo=K7#@<(o6OANe#V{NSD?^;~!&Vyb;g}tTDW~d zG-|fi_vq2@wb%WZ{;=6H>#GbzG4|d1fqi{K4@o!cck}PPKB?J)|#75Z=NohQD@EA%3jZj#_S~hYZP{D>& z2{AEL30(HIWeqc{kb!DNC8#80TcvDjjs-eHR%Rq*-;FYuvGMP2vG>Q$1)yTIVPV!y52fcn8 zxi!=DP!eSk+5gXz!$({Z$A!swIi)7znXUy938)vdjR_4&2Nn^EZlpD1BAcreN{~d0 zoJCNpp>T+6eZ&BqG>{KCwR7G85MM-L+GCVgY`iDKkW^fu#Ke_I-zHG9GQh&ZTG7us z^%O3HGrk|y`W{G;DXF(HW1wOdP;@F+Gmm1E7xk>q>d9KPgo>ws?*|8xfA(Y1*@LcC zlxBx7B5vV_2c*>!unfzpi@$vI!k$mv*ZkV=oob}M+pS12w(X@3)f-*gn8|%N`oL>X z9J%<|;|r_%E+Ag5N1vTPv51$8XT3D;^B4@}@{_@LzAWjyV|xB?{>Sqdc7OP6!RDuO zF)FH1Y0sQ$*)xq=ubCH(D8Ez^B}%a8)Az1si~DlT-Z%eSyXqhNyYD$QcjzDe?bB&` z$It$qJYIiy`&Z9p=4ib4oYuRwl|0Xj#Sh$dPuwniJMGS9_rCU>`Bg9L%YCl7taj*w zH|{<1YM(MGGAK3LtL?XmXL_S}J6(kIg5+mB?$x?Mfed;=FC{IaYf!=f66 z_Enj=?6u=XA`N}AK!%t~!NGD(Qw2r?hls%(d|m<@!%(7Fy?9RqikQVaQ<`wOH4s?q z00pT!23s@o@SHefilwJttS|hNV`|RVZR{LLFWc;iyF!F0fI=!cL?W>4pf#5cH*MnMomhE2|Ke19E9 z)CUUUd=71_WX2i^#A&47u*;LeWUKW&78?!nW{rAr-Y7Kg)%v#4%kmV?6v%Yo(te@D zxgDr!N}gLKDG-|=;0(jQ3S@Xk4Dkv!l}~_8(@cPd5 z`mt5dRP^d5${A9xEP#e;68Kdfy(x6-wK zZngCjcMkeZ>ST#c)J5-oHvQ&K!z5#lg5^ zq6T;B7t_|#t9|_b7hT$G&&@u3=kBHNXT=M@`C$CG*az;~%w@J9a2dxVHb6U#Jd74Z zQ1Q@2MZuN|RfP}&Xnnz^g@C_bo1;rqOi(Q1mT39tAX$l7}(%cZD zK)xUm3n-~$$2zM}m#T~5*dcX|5W&?dK_USyZkC!g!- zPtN~`r*BH1+Vj*uwZk)wJo6_moNw2XS{6sKonQLJ$G^1{w>rh*u@lGlnfCm;&h}q@ zaanG((ZCHY2V4Jc3(e%O|N5J4S3mO8Hx?dqW_CS=&pP!%qRGuKzc_R5*~112W3@3X zimDD=fNFU?Z$cXCknCv@Q-T76V~QH6=Ae14NOP`vBtVh4*|Q(29r(#bOEc0<$a}!56thar-sa|W0#;MsF#pJXeV?8DqjP(|bJpYe{G{V@ zp)=uxMvXzJ^rcWn0-=J^<}0azz*bf4c&WZ`%KgEpdLm*eZ2?gtRpK%z09RT91fw9( zATS2yfUHc2kx^~Sn52g0Qlzc(^Xo7F@yqVSM^B)icdKyDR)Hp4+6Hdoirdp(3o*EZ zOR?UHbI+Q$UV8aItt@}!7t+Y+xcg<{{pyOEI~GK8vYE>%>vdZ%JbH2ArU!oism6W3 z!)x8_?ZvwUF+WS-6vM>dVdUs7xiw}@$+)}iRbg*ygBd2v)HWL_z;2pA>|=?P8&EP z>FJ)Ie=u&Yhfj!D9M_I4o)E7lPNL;BQnI6UrrYfGj{e*4{`+3n&pq^E9Lek1uC>#D z`TE|abFIeF`H}c^)-l@JX;V{Z&)6dn!ya5ak^YFC$8&6#P_?wQJsa$V7>i1}`^1*Q?KeJ~k1{_;lQ3H4DwpFix}x z&QOx~BKY%PST5uf&;9YOf3G;*Dn7I@dSqCYqQszYLLY1s#o26MC{8UPxl4K(#GQJj z^;38CH66Kp2FK2=V$IM)cO5=FQ*h(nuJ8HC!6D%Y#oGIBe(Xyx^lQKNvv(eM>v?P1 zbyekS)davawiE-9x@weZ-*C|MGc2y4Y-XJetIzv05qqW41Q3D!@~)r$iBtOy4d&{v zxztCE`{wRGyJqI^ot@vcmVy9Msw4ob=}5R}P))Ojxvn~qp0 z+q#2V!-sC)|1L8~GCxusPVP8&M)SQte)q9@N@zO!{!ABR7o|RdkL!)~{N&t!eX_Cs z*pqyFqh}|Jq9&tNjB>>oGn^E)dGrS^x^Z!x+k7COcg!~_*}#8ib0 z!GI@X8ZF?k41tV%U_0zkRQ>=Ui1$))h?8rWB?3g$g8_(Adr?YcERHLw5npYG=qAI)c7NLHoMKvFlrOr6y!gUE4uUQd0~MD6+i zc<0aK+z+DkqdFC${_02A9&BF z(!riuv+w387w*)KH&@s=(@DF%+33un^wpa`|Gyul7f&~4_hrk?Dc4n|Rt>2v#>Q2x zR(LclY1L-ktFcy`kctN}K{oXG#2|}dn)!tra?eRZnsXvzTN&-C#a(Kr9Oj!#9%HLR zM;w*1E~_KUW$K$JE%W3;PJoW7jEB4dVxSt&AzVgWmgOU)^pFCEK!8yY5t~SLRK{D) zB_kIg@Bj=Ub?N$l{vo&drHf8uJ1A%(hPRD7>-(Bxy%=|M-_zXI}gAXFe9;&*o3R z_aoU@VqUBwAY*e;FjNwCg{qg97i0WD1%(kgxF;=&)c5Pd{2ABHcFiOm@kNxl0&&Mv z!Vy!>?HX%u4d3f@p83XeJzBeA_a36TD(prRe}@xSB|KeO-_`5tb59Zag*GLO^4L3X zXqc_1u@K2v@`a^le`SwdJhiu`eQTGK#6mQxY2lpnaZ*lzA!l!F*|k+te>otP1uPg@ z0IU!Y*q1|tNhJ{cA^cOmR$qArA|Z7gI{XJj!~IkYgDY=`lw@I*XyOs1`3*=xj2g~- zY39E7Ye#p-xo|t3VBfDdWFrMok{=-vW+Ujb83R|G8hp*TT=VUB9_m=1F}?3Uy5?@0 z@6P&p0x`>o)FTz~V!BNey5VW=Gk;d%v#lE*Z1vARnjZY%A&N)63axmrBuoUh3?)@z z%jkP%RWPSnNG9U#LPQWzQ(j&{1O|~WC3-6Mvua%wOn=qD zA1XWrVuY72#Jd_?S_leLPCN=GCcaG3&h#4(Y>3-eD`{!6gx8yyy8j<6x^D{zB<8%R z!}-if6# z^1ijEdF0l2-IAs1dAe~{J-3q!%VsA!&Gn+uc)IL%6<9^bIQmg@P#X;Ost%VOGZT-d zB})t)%Y~z^`C;E0HlPH zd)&c05tfR9CF9!lwb#FvoxAbDqciHfnL7V=V|0Z%KHf6@;`wg<{fov&YNoEpSBu0Q z{KaRw&)9vn*N=jv$$@e z*vQ~Q#st1rxr!}C(9o_>ih>4p4h>S+3IKb;6iOziV6(^UTB9t} z9$y{%7XWAwbaxP$drpZPjCPTYR*+gaOF#=Eja+M zRr=vZ4nPo!i#Z4aQ0$#2Ll_V}Yp>;XRIev))yDUCu(hz=DC!jn zN8kO8GjpHacSr6xAtz=kxn_W3(A~}6bn4F*hpd>?`J_9ImA_i_BRdE}vf{mhG%}(F zQzh%BMT<6j?xt}7S27+HqlJ4|?<^{$A+ZNgRR@5pfPVEQ0$09yWdU41od-}b(OA-E zHLhg1^<;f~0E0qGV5z<#WDOh5EitJ~TtB_}31ITZ(_}oKu7392!F7Z;pt|wD`_xN! zfAH?)%p9}VD`?W)vxBG>nRcY(*eHm<;DsuBKDA+b#=W zazEvhJ_?H}X-MlAAyPn0Ddu%4xmQ)ml#zWJ+tXzRG6hqjdR^rpS1W&2jd#fjQt2mz ztYh{TBCKS{kcuY(VW`G;UGnob#`Ed2)Z`Af8%jpV=#$@g;l|tF+u0R4BM#O~{b?FS z%@lc@<%Z|cjQerKd8irQ8ZZ|CH{*g>$rp%76?)~xW7`fQjOVouqX3Zz3=;|>79ldK zqM}l{01VX)K79z5Y1N8r@`8WF1C3;{^3{7Pa2TEtgwIuz5rpnf{jL zkE}Osu8dwHp4f=ea1DEVBB)@Dk=}#ShaZxCL9Cb93mI1NVZ05U8Z+T*Ob&dRcz-h9 z>6ka|Z{V$7ru9MhZJ{)(UHi97(B_`BKuEGY#+ybYk?x62~H zD?ZUlO!tL?DN&#-N-Ramw3He0;k~k!0P2XChD_bo2V6G<2zQ)=ZG)2C`^pRa{HIUM zK|EiR2J8+l9iP=iafx1_mdkUQEkvR1wqBxgI_&@zrj5|52q-jw6CqZjPfgSk+7xW}19PSKX*>+L0i%u^6xwA~N?}*H!>hSq0U%7IUWADRq~(em$bl+5 zpQ;ulRp?=7IKFcsfAy6YKK-7=OJt4vksbeU8*DZ{tJYxS20$3_rOZ!I`nYo6_iE#q z1eB`780rZQ_oHPe9#VKBA~T$*w*1M^@&9eH7}*X+cy0aj$*;f5*9j!5WUi$dvdyg!a2X+u&rmk&}ou$akjRYs2LD ztnMcbD!p)!2!arYhFDz4;b$nZFRC)#5h+&29n0-EFACat;3{6Z>E9`+I`drbm4BnD zf9c+*UFY=%RRS#Abi5xGjOjQ0NP$YFtgE%8uNKw)a-~Vn)wnJ{ACiCtkqU&%u-x1C zxeGqQmYM#gn&FS&BesLLIo{G|YJlUQ%s$ji}0iuG8+u_B%Syw_rxQ1XV|- p9lT>e+sRpNOa<0a{oKU&{~st@#QR@gK^*`9002ovPDHLkV1l1u1, + image_std: Vec, +} + +impl Default for ProcessorConfig { + fn default() -> Self { + Self { + do_resize: true, + height: 384, + width: 384, + do_rescale: true, + do_normalize: true, + image_mean: vec![0.5, 0.5, 0.5], + image_std: vec![0.5, 0.5, 0.5], + } + } +} + +pub struct ViTImageProcessor { + do_resize: bool, + height: u32, + width: u32, + do_normalize: bool, + image_mean: Vec, + image_std: Vec, +} + +impl ViTImageProcessor { + pub fn new(config: &ProcessorConfig) -> Self { + Self { + do_resize: config.do_resize, + height: config.height, + width: config.width, + do_normalize: config.do_normalize, + image_mean: config.image_mean.clone(), + image_std: config.image_std.clone(), + } + } + + pub fn preprocess(&self, images: Vec<&str>) -> Result { + let height = self.height as usize; + let width = self.width as usize; + let channels = 3; + + let images = self.load_images(images)?; + + let resized_images: Vec = if self.do_resize { + images + .iter() + .map(|image| self.resize(image.clone(), None).unwrap()) + .collect() + } else { + images + }; + + let normalized_images: Vec = if self.do_normalize { + resized_images + .iter() + .map(|image| self.normalize(image.clone(), None, None).unwrap()) + .collect() + } else { + let resized_images: Vec, Vec>> = + resized_images.iter().map(|image| image.to_rgb8()).collect(); + let data = resized_images + .into_iter() + .map(|image| image.into_raw()) + .collect::>>(); + + data.iter() + .map(|image| { + Tensor::from_vec(image.clone(), (height, width, channels), &Device::Cpu) + .unwrap() + .permute((2, 0, 1)) + .unwrap() + }) + .collect::>() + }; + + Tensor::stack(&normalized_images, 0) + } + + fn resize( + &self, + image: image::DynamicImage, + size: Option>, + ) -> Result { + let (height, width) = match &size { + Some(size) => (size.get("height").unwrap(), size.get("width").unwrap()), + None => (&self.height, &self.width), + }; + + let resized_image = + image.resize_exact(*width, *height, image::imageops::FilterType::Triangle); + + Ok(resized_image) + } + + fn normalize( + &self, + image: image::DynamicImage, + mean: Option>, + std: Option>, + ) -> Result { + let mean = match mean { + Some(mean) => mean, + None => self.image_mean.clone(), + }; + + let std = match std { + Some(std) => std, + None => self.image_std.clone(), + }; + + let mean = Tensor::from_vec(mean, (3, 1, 1), &Device::Cpu)?; + let std = Tensor::from_vec(std, (3, 1, 1), &Device::Cpu)?; + + let image = image.to_rgb8(); + let data = image.into_raw(); + + let height = self.height as usize; + let width = self.width as usize; + let channels = 3; + + let data = + Tensor::from_vec(data, &[height, width, channels], &Device::Cpu)?.permute((2, 0, 1))?; + + (data.to_dtype(DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std) + } + + pub fn load_images(&self, image_path: Vec<&str>) -> Result> { + let mut images: Vec = Vec::new(); + for path in image_path { + let img = image::io::Reader::open(path)?.decode().unwrap(); + images.push(img); + } + + Ok(images) + } +} diff --git a/candle-examples/examples/trocr/main.rs b/candle-examples/examples/trocr/main.rs new file mode 100644 index 00000000..e93d6b2f --- /dev/null +++ b/candle-examples/examples/trocr/main.rs @@ -0,0 +1,132 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Error as E; +use clap::{Parser, ValueEnum}; + +use candle::{DType, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::models::trocr; + +use tokenizers::Tokenizer; +mod image_processor; + +#[derive(Clone, Debug, Copy, ValueEnum)] +enum Which { + Base, + Large, +} + +#[derive(Parser, Debug)] +struct Args { + #[arg(long)] + model: Option, + + /// Choose the variant of the model to run. + #[arg(long, default_value = "base")] + which: Which, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Text to be translated + #[arg(long)] + image: String, +} + +pub fn main() -> anyhow::Result<()> { + use hf_hub::api::sync::Api; + let args = Args::parse(); + + let tokenizer_dec = { + let tokenizer = Api::new()? + .model(String::from("ToluClassics/candle-trocr-tokenizer")) + .get("tokenizer.json")?; + + Tokenizer::from_file(&tokenizer).map_err(E::msg)? + }; + + let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec); + + let device = candle_examples::device(args.cpu)?; + + let vb = { + let model = match args.model { + Some(model) => std::path::PathBuf::from(model), + None => match args.which { + Which::Base => Api::new()? + .repo(hf_hub::Repo::with_revision( + "microsoft/trocr-base-handwritten".to_string(), + hf_hub::RepoType::Model, + "refs/pr/3".to_string(), + )) + .get("model.safetensors")?, + Which::Large => Api::new()? + .repo(hf_hub::Repo::with_revision( + "microsoft/trocr-large-handwritten".to_string(), + hf_hub::RepoType::Model, + "refs/pr/6".to_string(), + )) + .get("model.safetensors")?, + }, + }; + println!("model: {:?}", model); + unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? } + }; + + let encoder_config = match args.which { + Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(), + Which::Large => { + candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten() + } + }; + + let decoder_config = trocr::TrOCRConfig::default(); + let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?; + + let config = image_processor::ProcessorConfig::default(); + let processor = image_processor::ViTImageProcessor::new(&config); + + let image = vec![args.image.as_str()]; + let image = processor.preprocess(image)?; + + let encoder_xs = model.encoder().forward(&image)?; + + let mut logits_processor = + candle_transformers::generation::LogitsProcessor::new(1337, None, None); + + let mut token_ids: Vec = vec![decoder_config.decoder_start_token_id]; + for index in 0..1000 { + let context_size = if index >= 1 { 1 } else { token_ids.len() }; + let start_pos = token_ids.len().saturating_sub(context_size); + let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?; + + let logits = model.decode(&input_ids, &encoder_xs, start_pos)?; + + let logits = logits.squeeze(0)?; + let logits = logits.get(logits.dim(0)? - 1)?; + let token = logits_processor.sample(&logits)?; + token_ids.push(token); + + if let Some(t) = tokenizer_dec.next_token(token)? { + use std::io::Write; + print!("{t}"); + std::io::stdout().flush()?; + } + if token == decoder_config.eos_token_id { + break; + } + } + + if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + println!(); + + Ok(()) +} diff --git a/candle-examples/examples/trocr/readme.md b/candle-examples/examples/trocr/readme.md new file mode 100644 index 00000000..329940f8 --- /dev/null +++ b/candle-examples/examples/trocr/readme.md @@ -0,0 +1,16 @@ +# candle-trocr + +`TrOCR` is a transformer OCR Model. In this example it is used to +transcribe image text. See the associated [model +card](https://huggingface.co/microsoft/trocr-base-printed) for details on +the model itself. + +## Running an example + +```bash +cargo run --example trocr --release -- --which base --cpu --image assets/trocr.png +``` + +``` + industry , Mr. Brown commented icily . " Let us have a +``` diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 370b9108..3c025660 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -29,6 +29,7 @@ pub mod segment_anything; pub mod stable_diffusion; pub mod stable_lm; pub mod t5; +pub mod trocr; pub mod vgg; pub mod vit; pub mod whisper; diff --git a/candle-transformers/src/models/trocr.rs b/candle-transformers/src/models/trocr.rs new file mode 100644 index 00000000..785b06ca --- /dev/null +++ b/candle-transformers/src/models/trocr.rs @@ -0,0 +1,434 @@ +use crate::models::vit::{Config, Embeddings, Encoder}; +use candle::{Result, Tensor}; +use candle_nn::{ + embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder, +}; +use serde::Deserialize; + +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct TrOCRConfig { + pub vocab_size: usize, + pub d_model: usize, + pub hidden_size: usize, + pub decoder_layers: usize, + pub decoder_attention_heads: usize, + pub decoder_ffn_dim: usize, + pub activation_function: candle_nn::Activation, + pub max_position_embeddings: usize, + pub dropout: f64, + pub attention_dropout: f64, + pub activation_dropout: f64, + pub decoder_start_token_id: u32, + pub init_std: f64, + pub decoder_layerdrop: f64, + pub use_cache: bool, + pub scale_embedding: bool, + pub use_learned_position_embeddings: bool, + pub layernorm_embedding: bool, + pub pad_token_id: usize, + pub bos_token_id: usize, + pub eos_token_id: u32, + pub num_attention_heads: usize, + pub decoder_vocab_size: Option, +} + +impl Default for TrOCRConfig { + fn default() -> Self { + Self { + vocab_size: 50265, + d_model: 1024, + hidden_size: 768, + decoder_layers: 12, + decoder_attention_heads: 16, + decoder_ffn_dim: 4096, + activation_function: candle_nn::Activation::Gelu, + max_position_embeddings: 512, + dropout: 0.1, + attention_dropout: 0.0, + activation_dropout: 0.0, + decoder_start_token_id: 2, + init_std: 0.02, + decoder_layerdrop: 0.0, + use_cache: true, + scale_embedding: false, + use_learned_position_embeddings: true, + layernorm_embedding: true, + pad_token_id: 1, + bos_token_id: 0, + eos_token_id: 2, + num_attention_heads: 12, + decoder_vocab_size: Some(50265), + } + } +} + +#[derive(Debug, Clone)] +struct TrOCRLearnedPositionalEmbedding { + offset: usize, + weights: Embedding, +} + +impl TrOCRLearnedPositionalEmbedding { + fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result { + let offset: usize = 2; + let num_embeddings = cfg.max_position_embeddings; + let embedding_dim = cfg.d_model; + let weights = embedding(num_embeddings + offset, embedding_dim, vb)?; + + Ok(Self { offset, weights }) + } + + fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result { + let (b_sz, seq_len) = input_ids.dims2()?; + + let mut positions = Tensor::arange( + past_key_values_length, + seq_len as u32 + past_key_values_length, + input_ids.device(), + )? + .expand((b_sz, seq_len))?; + + positions = + positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?; + self.weights.forward(&positions) + } +} + +#[derive(Debug, Clone)] +struct TrOCRAttention { + head_dim: usize, + num_heads: usize, + is_decoder: bool, + scaling: f64, + k_proj: Linear, + v_proj: Linear, + q_proj: Linear, + out_proj: Linear, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl TrOCRAttention { + fn load( + vb: VarBuilder, + cfg: &TrOCRConfig, + kdim: Option, + vdim: Option, + ) -> Result { + let embed_dim = cfg.d_model; + let num_heads = cfg.decoder_attention_heads; + let head_dim = embed_dim / num_heads; + let kdim = kdim.unwrap_or(embed_dim); + let vdim = vdim.unwrap_or(embed_dim); + + let k_proj = linear_no_bias(kdim, embed_dim, vb.pp("k_proj"))?; + let v_proj = linear_no_bias(vdim, embed_dim, vb.pp("v_proj"))?; + let q_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("q_proj"))?; + + let out_proj = linear_no_bias(embed_dim, embed_dim, vb.pp("out_proj"))?; + Ok(Self { + head_dim, + num_heads, + is_decoder: true, + scaling: 1. / (head_dim as f64).sqrt(), + k_proj, + v_proj, + q_proj, + out_proj, + kv_cache: None, + }) + } + + fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result { + tensor + .reshape((bsz, (), self.num_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous() + } + + fn forward( + &mut self, + xs: &Tensor, + kv_states: Option<&Tensor>, + attn_mask: Option<&Tensor>, + ) -> Result { + let (b_sz, tgt_len, _) = xs.dims3()?; + let query_states = (xs.apply(&self.q_proj)? * self.scaling)?; + let (key_states, value_states) = match kv_states { + None => { + let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?; + let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?; + if self.is_decoder { + let kv_states = match &self.kv_cache { + None => (key_states, value_states), + Some((p_key_states, p_value_states)) => { + let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?; + let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some(kv_states.clone()); + kv_states + } else { + (key_states, value_states) + } + } + Some(kv_states) => { + let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?; + let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?; + (key_states, value_states) + } + }; + let proj_shape = (b_sz * self.num_heads, (), self.head_dim); + let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?; + let key_states = key_states.reshape(proj_shape)?; + let value_states = value_states.reshape(proj_shape)?; + let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; + let attn_weights = match attn_mask { + None => attn_weights, + Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?, + }; + let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_output = attn_probs.matmul(&value_states)?; + attn_output + .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))? + .transpose(1, 2)? + .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))? + .apply(&self.out_proj) + } +} + +#[derive(Debug, Clone)] +struct TrOCRDecoderLayer { + self_attn: TrOCRAttention, + activation_fn: candle_nn::Activation, + self_attn_layer_norm: LayerNorm, + encoder_attn: TrOCRAttention, + encoder_attn_layer_norm: LayerNorm, + fc1: Linear, + fc2: Linear, + final_layer_norm: LayerNorm, +} + +impl TrOCRDecoderLayer { + fn load(vb: VarBuilder, cfg: &TrOCRConfig) -> Result { + let embed_dim = cfg.d_model; + let self_attn = TrOCRAttention::load(vb.pp("self_attn"), cfg, None, None)?; + let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?; + let encoder_attn = TrOCRAttention::load( + vb.pp("encoder_attn"), + cfg, + Some(cfg.hidden_size), + Some(cfg.hidden_size), + )?; + let encoder_attn_layer_norm = + layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?; + let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?; + let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?; + let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?; + let activation_fn = candle_nn::Activation::Gelu; + + Ok(Self { + self_attn, + activation_fn, + self_attn_layer_norm, + encoder_attn, + encoder_attn_layer_norm, + fc1, + fc2, + final_layer_norm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: &Tensor, + encoder_hidden_states: Option<&Tensor>, + ) -> Result { + let residual = xs.clone(); + let xs = self.self_attn.forward(xs, None, Some(attention_mask))?; + let xs = (xs + residual)?; + let mut xs = self.self_attn_layer_norm.forward(&xs)?; + + if let Some(encoder_hidden_states) = &encoder_hidden_states { + let residual = xs.clone(); + let encoder_attention_mask = attention_mask.clone(); // TODO + xs = self.encoder_attn.forward( + &xs, + Some(encoder_hidden_states), + Some(&encoder_attention_mask), + )?; + xs = (xs + residual)?; + xs = self.encoder_attn_layer_norm.forward(&xs)? + } + + let residual = xs.clone(); + let xs = self.fc1.forward(&xs)?; + let xs = self.activation_fn.forward(&xs)?; + let xs = self.fc2.forward(&xs)?; + let xs = (xs + residual)?; + let xs = self.final_layer_norm.forward(&xs)?; + + Ok(xs) + } +} + +#[derive(Debug, Clone)] +pub struct TrOCRDecoder { + layers: Vec, + embed_scale: Option, + embed_tokens: Embedding, + embed_positions: TrOCRLearnedPositionalEmbedding, +} + +impl TrOCRDecoder { + fn new(cfg: &TrOCRConfig, vb: VarBuilder) -> Result { + let vb = vb.pp("decoder.model.decoder"); + + let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?; + let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?; + let mut layers = Vec::with_capacity(cfg.decoder_layers); + let vb_l = vb.pp("layers"); + for idx in 0..cfg.decoder_layers { + let layer = TrOCRDecoderLayer::load(vb_l.pp(idx), cfg)?; + layers.push(layer) + } + let embed_scale = if cfg.scale_embedding { + Some((cfg.d_model as f64).sqrt()) + } else { + None + }; + + Ok(Self { + layers, + embed_scale, + embed_tokens, + embed_positions, + }) + } + + pub fn forward( + &mut self, + xs: &Tensor, + encoder_xs: Option<&Tensor>, + past_kv_len: usize, + attn_mask: &Tensor, + ) -> Result { + let embed_pos = self.embed_positions.forward(xs, past_kv_len as u32)?; + let xs = xs.apply(&self.embed_tokens)?; + + let xs = match self.embed_scale { + None => xs, + Some(scale) => (xs * scale)?, + }; + + let mut xs = xs.broadcast_add(&embed_pos)?; + + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attn_mask, encoder_xs)?; + } + Ok(xs) + } +} + +#[derive(Debug, Clone)] +pub struct TrOCREncoder { + embeddings: Embeddings, + encoder: Encoder, + layernorm: LayerNorm, +} + +impl TrOCREncoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb_v = vb.pp("encoder"); + + let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?; + + let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?; + let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?; + + Ok(Self { + embeddings, + encoder, + layernorm, + }) + } + + pub fn forward(&self, xs: &Tensor) -> Result { + let embedding_output = self.embeddings.forward(xs, None, false)?; + let encoder_outputs = self.encoder.forward(&embedding_output)?; + + self.layernorm.forward(&encoder_outputs) + } +} + +#[derive(Debug, Clone)] +pub struct TrOCRForCausalLM { + decoder: TrOCRDecoder, + output_projection: Linear, +} + +impl TrOCRForCausalLM { + pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result { + let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?; + let output_projection = + candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None); + Ok(Self { + decoder, + output_projection, + }) + } + + pub fn forward( + &mut self, + xs: &Tensor, + encoder_xs: Option<&Tensor>, + past_kv_len: usize, + attn_mask: &Tensor, + ) -> Result { + let xs = self + .decoder + .forward(xs, encoder_xs, past_kv_len, attn_mask)?; + let xs = xs.apply(&self.output_projection)?; + + Ok(xs) + } +} + +#[derive(Debug, Clone)] +pub struct TrOCRModel { + encoder: TrOCREncoder, + decoder: TrOCRForCausalLM, +} + +impl TrOCRModel { + pub fn new(encoder_cfg: &Config, decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result { + let encoder = TrOCREncoder::new(encoder_cfg, vb.clone())?; + let decoder = TrOCRForCausalLM::new(decoder_cfg, vb)?; + Ok(Self { encoder, decoder }) + } + + pub fn encoder(&mut self) -> &mut TrOCREncoder { + &mut self.encoder + } + + pub fn decoder(&mut self) -> &mut TrOCRForCausalLM { + &mut self.decoder + } + + pub fn decode( + &mut self, + xs: &Tensor, + encoder_xs: &Tensor, + past_kv_len: usize, + ) -> Result { + let seq_len = xs.dim(1)?; + let mask: Vec<_> = (0..seq_len) + .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) + .collect(); + let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?; + + self.decoder + .forward(xs, Some(encoder_xs), past_kv_len, &mask) + } +} diff --git a/candle-transformers/src/models/vit.rs b/candle-transformers/src/models/vit.rs index e2218c54..962528c1 100644 --- a/candle-transformers/src/models/vit.rs +++ b/candle-transformers/src/models/vit.rs @@ -6,16 +6,16 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder}; // https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py #[derive(Debug, Clone)] pub struct Config { - hidden_size: usize, - num_hidden_layers: usize, - num_attention_heads: usize, - intermediate_size: usize, - hidden_act: candle_nn::Activation, - layer_norm_eps: f64, - image_size: usize, - patch_size: usize, - num_channels: usize, - qkv_bias: bool, + pub hidden_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub intermediate_size: usize, + pub hidden_act: candle_nn::Activation, + pub layer_norm_eps: f64, + pub image_size: usize, + pub patch_size: usize, + pub num_channels: usize, + pub qkv_bias: bool, } impl Config { @@ -34,6 +34,21 @@ impl Config { qkv_bias: true, } } + + pub fn microsoft_trocr_base_handwritten() -> Self { + Self { + hidden_size: 768, + num_hidden_layers: 12, + num_attention_heads: 12, + intermediate_size: 3072, + hidden_act: candle_nn::Activation::Gelu, + layer_norm_eps: 1e-12, + image_size: 384, + patch_size: 16, + num_channels: 3, + qkv_bias: false, + } + } } #[derive(Debug, Clone)] @@ -76,7 +91,7 @@ impl Module for PatchEmbeddings { } #[derive(Debug, Clone)] -struct Embeddings { +pub struct Embeddings { cls_token: Tensor, mask_token: Option, patch_embeddings: PatchEmbeddings, @@ -85,7 +100,7 @@ struct Embeddings { } impl Embeddings { - fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result { + pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result { let hidden_size = cfg.hidden_size; let cls_token = vb.get((1, 1, hidden_size), "cls_token")?; let mask_token = if use_mask_token { @@ -115,7 +130,7 @@ impl Embeddings { todo!() } - fn forward( + pub fn forward( &self, pixel_values: &Tensor, bool_masked_pos: Option<&Tensor>, @@ -324,12 +339,12 @@ impl Module for Layer { } #[derive(Debug, Clone)] -struct Encoder { +pub struct Encoder { layers: Vec, } impl Encoder { - fn new(cfg: &Config, vb: VarBuilder) -> Result { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let vb = vb.pp("layer"); let mut layers = Vec::with_capacity(cfg.num_hidden_layers); for i in 0..cfg.num_hidden_layers {