From ee9ed342877423bcce2040d181a52674c297cf56 Mon Sep 17 00:00:00 2001 From: NunoSempere Date: Sun, 16 Jul 2023 21:26:33 +0200 Subject: [PATCH] move some functions from scratchpad => squiggle.c, reorg --- scratchpad/makefile | 2 +- scratchpad/scratchpad | Bin 22056 -> 22456 bytes scratchpad/scratchpad.c | 242 ++-------------------------------------- squiggle.c | 187 +++++++++++++++++++++++++++++++ squiggle.h | 34 +++++- 5 files changed, 232 insertions(+), 233 deletions(-) diff --git a/scratchpad/makefile b/scratchpad/makefile index cf6d6b5..fb1f8b2 100644 --- a/scratchpad/makefile +++ b/scratchpad/makefile @@ -9,7 +9,7 @@ CC=gcc # required for nested functions # CC=tcc # <= faster compilation # Main file -SRC=scratchpad.c +SRC=scratchpad.c ../squiggle.c OUTPUT=./scratchpad ## Dependencies diff --git a/scratchpad/scratchpad b/scratchpad/scratchpad index 51d92ed6da4d1b9e74bd3543b57b6f95aa528bd8..a689f5e709642f87afdaa537de697140f84d102d 100755 GIT binary patch delta 7687 zcmbVR3v?7!ny%_jC*65;btma`I+CP?9w#6s>F`J*2D?LmYG^b_G6q3H1On^`;gLX` z!=Yg^AT0#LULE6OT-IS`aK?2;aWuNa+6jt*6*jo6!>TBN9h&2JeUKZ!Btso7 z$zaK1uD>--9qLf9sbdi!7Qg%wb(}xzMtI7gj-za|=wu7UboNBKjYY(^k)2+ZG8`t7 zO0-0gJ!QD-<{whrr77-W7Grhl9(#O+($N_AMzQ}7VrJRJSL5m?u?@x?hXa~)9|d7E z*X&@g8-Fam%04wZ#V`2tADBHRPh`Bml0Tp4&sMfNCXd--(#7NKpJS@UKXUA5{#;8> zwkjr_Ei~neZ?j#d4T%rdEv*%ntii8w6_aA06bsq2u|EnK1fviG(?-FBU#t+vp&*!n zH3|m)6o55y_sHQG?lk4N9KWVrzeJABLNEG+3-fWDa7i0K%*Q4{M5IJ-$XIR>phHV# z2+EAYdW{)CNL{`>iRLZ+1_2`%AO-aj->nX)d{S7F*_lL>rMD{PH9K7kH6E%yZNfRjGFD@DGTEG% zWYPry8`&;k8-dlTmD9Gdc3}6K6J3aDu(b}+9M6uxIx;fY_OhETPG+_wCP{-hL1<&M zfi1Emy6i)+okKJwv->ShF?HW&i!t^eZ6nPi?-YbSE&Wo~ zgm{3q$D&;FsqSugWpeCxQ9zyk~s9$qX@&%!$ zL-0*n0?}EB9{jJOvY`Aa5Z@7^D115mwjAz=?gV2<4)=iN6_sTh$v2W=jyK;Iv079N*dbj5DJby)h|n9B5j20 zmll{PqJC-YeUWGzG^qwZMt#E5o5{$_+>LY5o@K*cSQx^5NMCvWdDJgpsj0*=e2>xQuKcRqKNf{x2 z&C*k|Mt?mCEDg6-%1tHjQCALsUY{s$FL@t>=-;j)^2IQD&~$`Vr%uS)J%XaRza2Vx z`5^2=QeNl!q<+bk)Wlv+9Vb4)zDWIz_&nR;7%$edcE<*6ZIjb(O>4TMD1DMG=t=LB z>}4KXt(4OqUCBae8HRCJ*wbmF4Av_wl2)F22QijOw#6Q|H`0%eDLI3AoMXixo9%Rp z9=6=MR-DgHI-e8wvrXv}#e?j%^f6h1f4fNY7yjgRg286fBqxES9|J@eFeQD>#4U*L z4VnJiQ|NR?9nagKP4u7zqdkRJG9ALk5jTh&#k6fd&0@tO%4+S<= zK^j2R$1*dwqb zH>ENdKFx~=c``i>wpwXr7YeElyO)1{8|_m4(tU@!eejd>fgj}8LH0(XV{~kk=vdse z*vR1R*A4zP;fD?0dflLc&3354u1hx<45$X*(6LdX({AFw*BkNwB-lJ*DR~^K41}-w zD+1w*aB?Crm2&vM>}T*WIC_AuFar(ETcP#DsyeV|Xon7FJ$*ZD+Z-dgLTm?Sn%rVoxp{N+-ScX^$JsOPj8NkGQkX?bX%-%XDjsh?cy5@5#2aP-j#tHaFbTGK z-9=!HGCX=;aTg4(llU@wuaFU_sQV_^1HpqAh(~6k;TirfFxT8QyD>G zs@kQTW75H6mFOb=W0C`-Lr7Cq7xF_$xwxyLD}27`)K%Z5gU0#^@*?D;D&~pia#=4*e2KYYAWoAC54?@2o!!6DUL zs^-FiZ67Is3m?ki^OT8@l=Gjm*mqr{WW&x8v0(+k^kFHdOFGy^N#R`wMiEL1jR1~O zQjU~`Pa!M*OK%zLy_6NdR8{e~fV}vNk4rm~=_FN?qZ0hZZ*8<7KQ!7B5@s#BiumSx z5@$5uv7a^$j4C|l=<&0tQom%I1%wR<+`P~MAL@}x$?V@AJGecyNA0vrwVy)u`I{>| z-|~G~Mh;s@N7;TTS%6BWina&$BgA)*%)d6A4*A2eG$Y}Nw*AjF*>L{_=K=4f-XvN< za??RXa&e;*=7WD0a79^nCK&!kN$u}VBs%(~O2w4{T;!k}l^p%qSA2nK`@}rlztqD+ z-L)wD$hsPL=(~ZRB&6an&F6Fg#!lxCkaXH#f*)HfdrO&(>BA(1lD%b12cY!qko+pi9aP8%wacWXdR<^;nMY4=%^6m2&fK8iF3DX4%7;?G2Z6i~r9-}!NA98in~D;`JcKrathb1GlW`KgV!MH< z`dCqX=r8+3QI-KV0&XIBN>PgNdb9UaMajXq^|Ye20=5CZ02n>3gp>$6Isc<5CjdLY zR1`OMlV!+Q1>j!5Wq|43it-Dv+g3UC`>n~6PAaEsW;UMg_L+DzLr2R1 zvX!YJ3AY=+7-a5N4M*~?;O7RefN-+aw!~@{%Fun6hQL1wUj@I_)?`@0>c$ltc7M)( zKdvAx@e4&M0fkyRe)D0z=nMAk__Sanx;sE#MYL&F+YZBqc&quY1R2A}KsyE6>l*D* zLzy*wr!io4H^)r1=Ik)dwB~P(4O)wu;sREmVHR6B!BHqD2qb$7vK5%l|BI{*vST4T z7E@b*TLs)6;PQsxsAeL-P1JA{dn<5f zfa@TfY;_x6M7&cIxZYPl+l`u|*Gy`}AL<11T_8^oh6Fkx;7{Ph&p>NIQGPU(24!v< zqtGf)sM*MbFqP(Hy47ur72~ZoKUt(eJm4uuu|7}yXoRB`a|XEa93$Mz(-{OZ9aS8F zL7JvV4W%>{QwOJ_jHbd*Q!(9IZu~i{6#1?Berryt)eY8Cp1^4dc)b={p=-oh);I3< z*b~$}&kDxhS5Sn~38L)Bsb+b(27zlFf+G@@+hX8%_c6o7Qu7Pw(rRHBd@1`{EW1n= z{ri}W1xm8T^=x@bw!LbMs@0_eA)h@7S`mA_Bu_lSy3x;NmdSZ`pY}P4qNKKo1tw?P zx2ZFakA!OLWDiZg+fWq84o}W9#Ky6n$+e=7&A#(tiw(C&YVLiBg-h2P&ZjcNlw5H)^G*DU4*&)Ac6jNfjO|)TtUIYUl_bI0XGOr_S>+ zI-&q~hiU#&;~#GEGRLO~`loJPPZE}eh()ANxsD%&k~9d3C>UA{L)s_^Bm-Hl)e^Q} z-o~-^GN*Vqf7Z~`v<94R_L3pRyHk_dJzC?{Watoa3#i6&4WZ+?mY3=2I-tqg@I{Fl z-^l<^vdvSSuFo{8TiderE#@4_dJWJ|zUFH?l)jX zhej{f@aHv8^et{GNx3R@TiTx%`bNKisA3)G*{ShtqC*5tKU1D%tEV|le*l%|OMHX4 zYmQ)SH@F1-OSUQ_2p?!#7HV2FH-#SJN4!G5%84z3x4&up!bo+j?Px+Q+bla>E=)A# z=#(}$OEjs`oLZt$%WqQAdiMXD)l7HR?9nvqf9B{JnX5_VyHzQ^_oBNT4TwrxGFQJ>K8tA;}k(`l_{x$71czeS602KYyE-L;zJ!eQ%e9pQuTzSMGw|3`som+ z+zA3v>sD0b(p!rwFUy#j=^uIl=EC{a^>x)=0e?~rku4S$Ev~J8U`c%q5pFcSa_08L znl)=`9;#l{uwre)3N|pa!lGHKuVXWVKfPTWY4_~%fbIx%X}Gn^R}L`;S63`=tZ!Je zgdGYNTU6~(y-%DKXEfLhEWSL~enQ*VbePKm<#~4f4AbFQ76w*7!F1Td4wdIw^z*Be ziL(m_4^RCN#gRHYWVUG~$X3sGvM*=bE&1BXna{*~3gdm+$Y*EA?n+}1&6e2Gd+f=2 z0qf9_&K|iZH(uY(aOid1GYyBHWlmnQzTfK*e`(H`lbfucg*uGHut(xQ5OE>P!NTANaK@Q0QuDlA>1YBl!5$ ztg=Dp$t~CC3Ei!=wpz;`AGNIp+{#?rw5>KjO0=hX%`ZMIjWyr>-S@*qYX8~G^WJmM z@0@$iJ?Gwg-gh|g7h%s!LW@1FC7#L6qY$>ZvNh-Vs3xPrHW^z_ykRJ>*uPf%KRHvtx{Q=qiVJ zR0|%}f<{BCpAZ@K)d}9gyFYxVX>wm>{=Q$g42|6K_v*@zX{|oX%EJ@zC~ICX+m5G$ z`n!ae=$zgroS|J|d6X4q7lJbXEY%wFs15v4IuTYaJSU@TT}VXj&#C z6_(Yj+w4;_t>G@}jwp4gCqmX0rzlf+QlqFM(&mU&ZE$jY#^-@Ap7z&ONeb z{)UUfsvY78h`GJD`LV_NZX|bRjh`#vf{6BlG3rvNWSDb zavY%tx&zVla|Fhsy5?P_?lzPGAZ^!AX(GzVeDE94EQndH5*De z8FQn->f36#H5TJdp$4PWX-;jz8rWf#Lf;|lkBXzy@yWs`bR*uG{R#Bhs^L|>4aKkV zp)dGV!yw<*;{9O6gEx`FI=Fn-@G^~080FlS%!8TV4x93M<4G~+9a$68YqrEJrzaB% zgjw`%!gIn?wA4CMsG(ifwZar~CJs*=`fr{~OL>O9#~fRhVOeC(Y8OpZotUKivX6e9 zI8^smAGIY063+BNvCM3#%SdyFdNDBEa!{Dvc(* zBX689cxu6gQXc(#*J+Dmq%eX0>KN**`|oV%!$6l+Ht;=&$mmB$jxda}lAOXJnv|3h zy$5zL3`P3Psq{pKHEJz*o~Vt9^h(A$U1TEV4trd;&PoS{-FL4OHde<%Jj#xdIj`5e zix8r5ai>Hc_)j%FrB-Gj;ro7S=f zx{{fjkSw#aPn$PSL>{;^Qv$i}+Nygrp5DA`x9(&d)n++#$pWp-TAdvU9oelMHjKMN zTIPCj9^v!{5K-SlgR@i9c`d!T|Mq$p59j>z;@}=!l0~`Z)QyaeC0B-TM?pT^LN(c0 z#>E&0!!KOENv+w1#w>+f)(0-0>Qdr#eHMBorQF!1T*7k}`XD7UdOcjq_lb4WtrVwj zZx=b8i6~Hr#*hP{KX}s@ye_W!L6SCwRu%Z##Wa3QarNYTiZf^!4z)$*hmPaUo<(W$ z&O3k_!mS>?P!Dgio^n;QaXZ~1aju4nT+S=4;_nc|vcJab;v^jK|K;!v=&25u4{*pd z(~84`5X0fOSJh>RTrNADc4;t(%_UEEGELL7NRUVnhsw!`B0SMIEB?RbHg zXLrf9jhD3!6MZ;1sWAzrmCZ#0v2X-Y9=swaP`m{bFlOE~A6MYw(`w1^4vgb0EZ>ep z3*$wA{leG>Q(`5$0bvOJ;$UhV8 z>X$BDQT@>J70YUuR=cX*@O#?q_uzwLd`-#jn4>-ZhZ);h0{(4?#!FHoszKvA@Oh`r z?|bNOcOi1bScrGLQH|TxC=Aw*<)5i`JT+fP690}1oRy?3&>Bz=sHans!tm-nehx{6 zUg*3e%?92Hx)jv%r6e_h20)(%?fOQNIzd}RN#d{I}qg6;(kfTrPSF9o%9OVR;QeAh@tSXe!1Qy6`epXuzx zTJs=bWf8cujAh^%iD&lbaKR|CAbn06!jxrPj)oG)8t~YGcYukeQgRt@9l^e&} z@GJ#3Tt<}nuVef;wHDZfz4SuCBwgg!)Kie3*zmO^6{~7^m9lxrO54}8dh{>Cck-b~ z`$x}hyourQZzTyI4RYg7EJ(*0U~(*+$CCl<2CxEIOTmhOWgxE_#`m!tSO8cD*Ylau zbgx92A}2)4d(876c1e=zxoo1zvQ4+P5z@+NK5-+2v+oqjG9P4Zki7)iWp&~cx_eFb z-|9FdZ9Q^ z*h42VcF+&Sd9i^hiUMMlTY$!lOO18mla@FB#m`$kEg3gO*A-5O#yNEz;dE-;A|XKD z@muzNHr_74V7JHS?&$wZLYaqOOh(rAvv?GY4)I#`(ac3ty^{nBy$a6E(m-uVEKW7h zkLXz}E6Skr-drI>u_ZP~OMXA;ZHHFG>MgM$R&R+_Z({*^7+ipYC2MpRE9px4ba$!E z(S=Wo{*7M&9geVzYD@Dly|vVqq@`BVdkO=A!9@j~D$R`I7fQg@oH3fB$}$@}WHl?( ze$Z*N{wb?rDHR%`&@(dO*6zfw%4|H-eyM9x9S)Cx3!-hBhJT9P)3HH0%4O3~P&+2X z@sQNhU{GaQ)Lmv1%H)2g-2aM(Ot3jNsQO*0ilTO3YiQ~O8@_%q%kd{w=XkoJBVP{; z+v!#49FcX>xsHx+m+DX@r)BAN)k2f%MXMfqdU&GE9gkq-HK!?)l%z72sp5eo1=lK8 zo;<5kmA3IZiKdmk1;8!*76cSr>mMXxl1+Gw&lQ@KTW1TA1{sE*#JqUN2&#FLmCTF8h^E;g(-wR%B8|@c%QP9>*yRmS>hhxxZk??lw{S%as2Y)ZIkM|19&>u zaD1U^>>Y3vuW%2>TiSdYbMIk_nA}U7DiTxfTmA^EURYP#P|NOLy8QkJJ)TMB{*uz_ z{M`K9Q8c?eh1QoR5B`~tl>6%zRo}O;p_cZSZ>7qL>GbD{2jT;MwsOfsciNazIcDFJ zmEn3_r$9TV=Em+(Z-0iiJ;jULfcQdKS;sj|@QDwk2Kc&M6Hg+{H; z1?WUo;q6yUV7#)fz2ulF#o)ytH-63XXxOabMvq@XVrb5+!l(+hCMF3qH{kpa^rmff diff --git a/scratchpad/scratchpad.c b/scratchpad/scratchpad.c index 3a91fa6..a7a6506 100644 --- a/scratchpad/scratchpad.c +++ b/scratchpad/scratchpad.c @@ -1,33 +1,13 @@ -#include // FLT_MAX, FLT_MIN -#include // INT_MAX #include // erf, sqrt #include #include #include -// #include #include +#include "../squiggle.h" -#define EXIT_ON_ERROR 0 -#define MAX_ERROR_LENGTH 500 -#define PROCESS_ERROR(...) \ - do { \ - if (EXIT_ON_ERROR) { \ - printf("@, in %s (%d)", __FILE__, __LINE__); \ - exit(1); \ - } else { \ - char error_msg[MAX_ERROR_LENGTH]; \ - snprintf(error_msg, MAX_ERROR_LENGTH, "@, in %s (%d)", __FILE__, __LINE__); \ - struct box error = { .empty = 1, .error_msg = error_msg }; \ - return error; \ - } \ - } while (0) #define NUM_SAMPLES 1000000 - -struct box { - int empty; - float content; - char* error_msg; -}; +#define STOP_BETA 1.0e-8 +#define TINY_BETA 1.0e-30 // Example cdf float cdf_uniform_0_1(float x) @@ -59,16 +39,10 @@ float cdf_normal_0_1(float x) return 0.5 * (1 + erf((x - mean) / (std * sqrt(2)))); // erf from math.h } -// [x] to do: add beta. -// [x] for the cdf, use this incomplete beta function implementation, based on continuous fractions: -// -// - -#define STOP_BETA 1.0e-8 -#define TINY_BETA 1.0e-30 struct box incbeta(float a, float b, float x) { // Descended from , + // // but modified to return a box struct and floats instead of doubles. // [ ] to do: add attribution in README // Original code under this license: @@ -174,200 +148,6 @@ struct box cdf_beta(float x) } } -// Inverse cdf at point -// Two versions of this function: -// - raw, dealing with cdfs that return floats -// - box, dealing with cdfs that return a box. - -// Inverse cdf -struct box inverse_cdf_float(float cdf(float), float p) -{ - // given a cdf: [-Inf, Inf] => [0,1] - // returns a box with either - // x such that cdf(x) = p - // or an error - // if EXIT_ON_ERROR is set to 1, it exits instead of providing an error - - float low = -1.0; - float high = 1.0; - - // 1. Make sure that cdf(low) < p < cdf(high) - int interval_found = 0; - while ((!interval_found) && (low > -FLT_MAX / 4) && (high < FLT_MAX / 4)) { - // ^ Using FLT_MIN and FLT_MAX is overkill - // but it's also the *correct* thing to do. - - int low_condition = (cdf(low) < p); - int high_condition = (p < cdf(high)); - if (low_condition && high_condition) { - interval_found = 1; - } else if (!low_condition) { - low = low * 2; - } else if (!high_condition) { - high = high * 2; - } - } - - if (!interval_found) { - PROCESS_ERROR("Interval containing the target value not found, in function inverse_cdf"); - } else { - - int convergence_condition = 0; - int count = 0; - while (!convergence_condition && (count < (INT_MAX / 2))) { - float mid = (high + low) / 2; - int mid_not_new = (mid == low) || (mid == high); - // float width = high - low; - // if ((width < 1e-8) || mid_not_new){ - if (mid_not_new) { - convergence_condition = 1; - } else { - float mid_sign = cdf(mid) - p; - if (mid_sign < 0) { - low = mid; - } else if (mid_sign > 0) { - high = mid; - } else if (mid_sign == 0) { - low = mid; - high = mid; - } - } - } - - if (convergence_condition) { - struct box result = { .empty = 0, .content = low }; - return result; - } else { - PROCESS_ERROR("Search process did not converge, in function inverse_cdf"); - } - } -} - -struct box inverse_cdf_box(struct box cdf_box(float), float p) -{ - // given a cdf: [-Inf, Inf] => Box([0,1]) - // returns a box with either - // x such that cdf(x) = p - // or an error - // if EXIT_ON_ERROR is set to 1, it exits instead of providing an error - - float low = -1.0; - float high = 1.0; - - // 1. Make sure that cdf(low) < p < cdf(high) - int interval_found = 0; - while ((!interval_found) && (low > -FLT_MAX / 4) && (high < FLT_MAX / 4)) { - // ^ Using FLT_MIN and FLT_MAX is overkill - // but it's also the *correct* thing to do. - struct box cdf_low = cdf_box(low); - if (cdf_low.empty) { - PROCESS_ERROR(cdf_low.error_msg); - } - - struct box cdf_high = cdf_box(high); - if (cdf_high.empty) { - PROCESS_ERROR(cdf_low.error_msg); - } - - int low_condition = (cdf_low.content < p); - int high_condition = (p < cdf_high.content); - if (low_condition && high_condition) { - interval_found = 1; - } else if (!low_condition) { - low = low * 2; - } else if (!high_condition) { - high = high * 2; - } - } - - if (!interval_found) { - PROCESS_ERROR("Interval containing the target value not found, in function inverse_cdf"); - } else { - - int convergence_condition = 0; - int count = 0; - while (!convergence_condition && (count < (INT_MAX / 2))) { - float mid = (high + low) / 2; - int mid_not_new = (mid == low) || (mid == high); - // float width = high - low; - if (mid_not_new) { - // if ((width < 1e-8) || mid_not_new){ - convergence_condition = 1; - } else { - struct box cdf_mid = cdf_box(mid); - if (cdf_mid.empty) { - PROCESS_ERROR(cdf_mid.error_msg); - } - float mid_sign = cdf_mid.content - p; - if (mid_sign < 0) { - low = mid; - } else if (mid_sign > 0) { - high = mid; - } else if (mid_sign == 0) { - low = mid; - high = mid; - } - } - } - - if (convergence_condition) { - struct box result = { .empty = 0, .content = low }; - return result; - } else { - PROCESS_ERROR("Search process did not converge, in function inverse_cdf"); - } - } -} - -// Some randomness functions for: -// - Sampling from a cdf -// - Benchmarking against a previous approach, which will be faster, but less general - -// Get random number between 0 and 1 -uint32_t xorshift32(uint32_t* seed) -{ - // Algorithm "xor" from p. 4 of Marsaglia, "Xorshift RNGs" - // See - // https://en.wikipedia.org/wiki/Xorshift - // Also some drama: , - - uint32_t x = *seed; - x ^= x << 13; - x ^= x >> 17; - x ^= x << 5; - return *seed = x; -} - -// Distribution & sampling functions -float rand_0_to_1(uint32_t* seed) -{ - return ((float)xorshift32(seed)) / ((float)UINT32_MAX); -} - -// Sampler based on inverse cdf and randomness function -struct box sampler_box_cdf(struct box cdf(float), uint32_t* seed) -{ - float p = rand_0_to_1(seed); - struct box result = inverse_cdf_box(cdf, p); - return result; -} -struct box sampler_float_cdf(float cdf(float), uint32_t* seed) -{ - float p = rand_0_to_1(seed); - struct box result = inverse_cdf_float(cdf, p); - return result; -} - -// Comparison point with raw normal sampler -const float PI = 3.14159265358979323846; -float sampler_normal_0_1(uint32_t* seed) -{ - float u1 = rand_0_to_1(seed); - float u2 = rand_0_to_1(seed); - float z = sqrtf(-2.0 * log(u1)) * sin(2 * PI * u2); - return z; -} - // Some testers void test_inverse_cdf_float(char* cdf_name, float cdf_float(float)) { @@ -445,12 +225,12 @@ int main() test_and_time_sampler_float("cdf_normal_0_1", cdf_normal_0_1, seed); // Get some normal samples using a previous approach - printf("\nGetting some samples from sampler_normal_0_1\n"); + printf("\nGetting some samples from unit_normal\n"); clock_t begin_2 = clock(); for (int i = 0; i < NUM_SAMPLES; i++) { - float normal_sample = sampler_normal_0_1(seed); + float normal_sample = unit_normal(seed); // printf("%f\n", normal_sample); } @@ -460,11 +240,11 @@ int main() // Test box sampler test_and_time_sampler_box("cdf_beta", cdf_beta, seed); - // Ok, this is slower than python!! - // Partly this is because I am using a more general algorithm, - // which applies to any cdf - // But I am also using really anal convergence conditions. - // This could be optimized. + // Ok, this is slower than python!! + // Partly this is because I am using a more general algorithm, + // which applies to any cdf + // But I am also using really anal convergence conditions. + // This could be optimized. free(seed); return 0; diff --git a/squiggle.c b/squiggle.c index 0c6d638..4d38779 100644 --- a/squiggle.c +++ b/squiggle.c @@ -1,6 +1,25 @@ #include #include +#include #include +#include +#include +#include +// #include +#define EXIT_ON_ERROR 0 +#define MAX_ERROR_LENGTH 500 +#define PROCESS_ERROR(...) \ + do { \ + if (EXIT_ON_ERROR) { \ + printf("@, in %s (%d)", __FILE__, __LINE__); \ + exit(1); \ + } else { \ + char error_msg[MAX_ERROR_LENGTH]; \ + snprintf(error_msg, MAX_ERROR_LENGTH, "@, in %s (%d)", __FILE__, __LINE__); \ + struct box error = { .empty = 1, .error_msg = error_msg }; \ + return error; \ + } \ + } while (0) // PI constant const float PI = M_PI; // 3.14159265358979323846; @@ -112,3 +131,171 @@ float mixture(float (*samplers[])(uint32_t*), float* weights, int n_dists, uint3 free(cumsummed_normalized_weights); return result; } + +// Sample from an arbitrary cdf +struct box { + int empty; + float content; + char* error_msg; +}; + +// Inverse cdf at point +// Two versions of this function: +// - raw, dealing with cdfs that return floats +// - input: cdf: float => float, p +// - output: Box(number|error) +// - box, dealing with cdfs that return a box. +// - input: cdf: float => Box(number|error), p +// - output: Box(number|error) +struct box inverse_cdf_float(float cdf(float), float p) +{ + // given a cdf: [-Inf, Inf] => [0,1] + // returns a box with either + // x such that cdf(x) = p + // or an error + // if EXIT_ON_ERROR is set to 1, it exits instead of providing an error + + float low = -1.0; + float high = 1.0; + + // 1. Make sure that cdf(low) < p < cdf(high) + int interval_found = 0; + while ((!interval_found) && (low > -FLT_MAX / 4) && (high < FLT_MAX / 4)) { + // ^ Using FLT_MIN and FLT_MAX is overkill + // but it's also the *correct* thing to do. + + int low_condition = (cdf(low) < p); + int high_condition = (p < cdf(high)); + if (low_condition && high_condition) { + interval_found = 1; + } else if (!low_condition) { + low = low * 2; + } else if (!high_condition) { + high = high * 2; + } + } + + if (!interval_found) { + PROCESS_ERROR("Interval containing the target value not found, in function inverse_cdf"); + } else { + + int convergence_condition = 0; + int count = 0; + while (!convergence_condition && (count < (INT_MAX / 2))) { + float mid = (high + low) / 2; + int mid_not_new = (mid == low) || (mid == high); + // float width = high - low; + // if ((width < 1e-8) || mid_not_new){ + if (mid_not_new) { + convergence_condition = 1; + } else { + float mid_sign = cdf(mid) - p; + if (mid_sign < 0) { + low = mid; + } else if (mid_sign > 0) { + high = mid; + } else if (mid_sign == 0) { + low = mid; + high = mid; + } + } + } + + if (convergence_condition) { + struct box result = { .empty = 0, .content = low }; + return result; + } else { + PROCESS_ERROR("Search process did not converge, in function inverse_cdf"); + } + } +} + +struct box inverse_cdf_box(struct box cdf_box(float), float p) +{ + // given a cdf: [-Inf, Inf] => Box([0,1]) + // returns a box with either + // x such that cdf(x) = p + // or an error + // if EXIT_ON_ERROR is set to 1, it exits instead of providing an error + + float low = -1.0; + float high = 1.0; + + // 1. Make sure that cdf(low) < p < cdf(high) + int interval_found = 0; + while ((!interval_found) && (low > -FLT_MAX / 4) && (high < FLT_MAX / 4)) { + // ^ Using FLT_MIN and FLT_MAX is overkill + // but it's also the *correct* thing to do. + struct box cdf_low = cdf_box(low); + if (cdf_low.empty) { + PROCESS_ERROR(cdf_low.error_msg); + } + + struct box cdf_high = cdf_box(high); + if (cdf_high.empty) { + PROCESS_ERROR(cdf_low.error_msg); + } + + int low_condition = (cdf_low.content < p); + int high_condition = (p < cdf_high.content); + if (low_condition && high_condition) { + interval_found = 1; + } else if (!low_condition) { + low = low * 2; + } else if (!high_condition) { + high = high * 2; + } + } + + if (!interval_found) { + PROCESS_ERROR("Interval containing the target value not found, in function inverse_cdf"); + } else { + + int convergence_condition = 0; + int count = 0; + while (!convergence_condition && (count < (INT_MAX / 2))) { + float mid = (high + low) / 2; + int mid_not_new = (mid == low) || (mid == high); + // float width = high - low; + if (mid_not_new) { + // if ((width < 1e-8) || mid_not_new){ + convergence_condition = 1; + } else { + struct box cdf_mid = cdf_box(mid); + if (cdf_mid.empty) { + PROCESS_ERROR(cdf_mid.error_msg); + } + float mid_sign = cdf_mid.content - p; + if (mid_sign < 0) { + low = mid; + } else if (mid_sign > 0) { + high = mid; + } else if (mid_sign == 0) { + low = mid; + high = mid; + } + } + } + + if (convergence_condition) { + struct box result = { .empty = 0, .content = low }; + return result; + } else { + PROCESS_ERROR("Search process did not converge, in function inverse_cdf"); + } + } +} + +// Sampler based on inverse cdf and randomness function +struct box sampler_box_cdf(struct box cdf(float), uint32_t* seed) +{ + float p = rand_0_to_1(seed); + struct box result = inverse_cdf_box(cdf, p); + return result; +} +struct box sampler_float_cdf(float cdf(float), uint32_t* seed) +{ + float p = rand_0_to_1(seed); + struct box result = inverse_cdf_float(cdf, p); + return result; +} diff --git a/squiggle.h b/squiggle.h index 96e70c1..4adc838 100644 --- a/squiggle.h +++ b/squiggle.h @@ -4,13 +4,29 @@ // uint32_t header #include +// Macros +#define EXIT_ON_ERROR 0 +#define MAX_ERROR_LENGTH 500 +#define PROCESS_ERROR(...) \ + do { \ + if (EXIT_ON_ERROR) { \ + printf("@, in %s (%d)", __FILE__, __LINE__); \ + exit(1); \ + } else { \ + char error_msg[MAX_ERROR_LENGTH]; \ + snprintf(error_msg, MAX_ERROR_LENGTH, "@, in %s (%d)", __FILE__, __LINE__); \ + struct box error = { .empty = 1, .error_msg = error_msg }; \ + return error; \ + } \ + } while (0) + // Pseudo Random number generator uint32_t xorshift32(uint32_t* seed); // Distribution & sampling functions float rand_0_to_1(uint32_t* seed); float rand_float(float max, uint32_t* seed); -float ur_normal(uint32_t* seed); +float unit_normal(uint32_t* seed); float random_uniform(float from, float to, uint32_t* seed); float random_normal(float mean, float sigma, uint32_t* seed); float random_lognormal(float logmean, float logsigma, uint32_t* seed); @@ -23,4 +39,20 @@ void array_cumsum(float* array_to_sum, float* array_cumsummed, int length); // Mixture function float mixture(float (*samplers[])(uint32_t*), float* weights, int n_dists, uint32_t* seed); +// Box +struct box { + int empty; + float content; + char* error_msg; +}; + +// Inverse cdf +struct box inverse_cdf_float(float cdf(float), float p); +struct box inverse_cdf_box(struct box cdf_box(float), float p); + +// Samplers from cdf +struct box sampler_box_cdf(struct box cdf(float), uint32_t* seed); +struct box sampler_float_cdf(float cdf(float), uint32_t* seed); + #endif +