diff --git a/examples/graph_size/data b/examples/graph_size/data new file mode 100644 index 00000000..98bf428f --- /dev/null +++ b/examples/graph_size/data @@ -0,0 +1,500 @@ +-42.1062906857, -42.0336814672 +-43.8005359566, -45.1675567427 +-42.4724208905, -41.9724295562 +-42.6648968305, -41.6057141673 +-42.3868633524, -41.3779388706 +-41.2875781684, -41.943036697 +-40.8688839541, -40.666920062 +-40.6544895126, -41.5523463732 +-40.9072622644, -39.8579458692 +-40.3513021146, -39.7746382213 +-38.3811209627, -37.6757363329 +-37.9229071229, -38.6732762515 +-38.3609946139, -38.4746416224 +-39.0521667891, -39.8775451578 +-39.2201881913, -39.204502537 +-39.0872934478, -37.2213211925 +-38.53791316, -39.5340020415 +-38.4577569364, -38.0705214195 +-39.8523579224, -38.3856073221 +-40.6947760351, -41.3031106763 +-40.4307265993, -41.178726689 +-38.6151982457, -39.1369324738 +-39.4955771716, -39.0180385962 +-39.1202606802, -39.9116074431 +-39.7738911208, -39.8200846432 +-38.483999758, -39.2282811522 +-37.5166898991, -36.9437599923 +-34.9063603488, -35.9415177929 +-35.2574799359, -35.4487007778 +-35.1139666888, -33.0484738046 +-35.6009604879, -35.9345178147 +-35.3058143188, -38.2191105421 +-33.2791812046, -33.9480237935 +-34.1735675991, -35.3708546275 +-34.2489507168, -33.7917890248 +-35.1609651494, -34.4166899953 +-34.715815402, -35.7197224739 +-33.8822593297, -35.9927879834 +-34.2794435086, -35.5288378089 +-33.8970790044, -32.3817670965 +-32.6331014014, -32.8301360602 +-34.0663028979, -35.3378009513 +-33.7060297087, -34.2228377192 +-33.3171866607, -32.4453329821 +-33.1441032756, -31.2857533221 +-31.9518157511, -31.8794969872 +-31.5931160031, -32.0507993278 +-31.347284575, -31.2381091063 +-32.1164977965, -32.6249320402 +-32.5314962549, -32.6208620849 +-32.5511974463, -31.0853623839 +-34.0949958401, -35.3547582148 +-33.7735809722, -35.7588715044 +-32.5808879496, -32.5288759074 +-31.7376727892, -31.8153852254 +-31.3431161891, -32.4544654458 +-30.3849309755, -31.10584215 
+-32.0073363936, -31.5664331416 +-31.191618407, -31.0178378171 +-34.1112330473, -33.4037874507 +-35.4996826338, -36.6565279258 +-34.331954454, -34.0361562005 +-34.5326403226, -34.5113654554 +-35.3101873141, -34.7941467664 +-34.7086437456, -34.9184089409 +-35.1759310185, -35.3039487668 +-35.061765007, -35.0447409264 +-34.8730062712, -34.7784970848 +-33.6851453926, -32.8753491812 +-33.2509029011, -33.533045941 +-32.583009551, -34.2221158455 +-33.5811268249, -33.2235064529 +-32.6208071631, -32.6069613132 +-31.6225455437, -34.127442424 +-30.5215387396, -31.7075884355 +-30.0352171671, -30.2798055406 +-30.4034597768, -30.8648394138 +-32.1724922104, -33.9037457664 +-33.5139539326, -33.3514839438 +-33.0310891361, -34.1364527721 +-32.0874091233, -32.7919986748 +-32.0390834149, -31.5833160881 +-32.858534182, -33.1238104216 +-35.2475652743, -36.030089323 +-34.4223935973, -33.8249414218 +-33.7191546225, -33.412188164 +-33.1644631012, -31.9425509236 +-30.9335331759, -30.128585149 +-31.3151573046, -30.0625074095 +-30.0691513871, -29.9336303681 +-29.2587695144, -28.7417348036 +-27.149541748, -26.5617559112 +-25.4388874528, -25.577305528 +-26.4525439159, -26.9081854362 +-24.6487652565, -25.0626630698 +-25.3482981681, -23.4346832166 +-26.1828421027, -26.0609424322 +-24.911286586, -24.4736669157 +-24.6236352983, -25.5635757554 +-25.4161698309, -24.9246871756 +-25.1956945856, -22.8949023013 +-24.343257184, -24.279319482 +-25.4610471508, -25.2664162922 +-24.9143941019, -25.4684990031 +-25.8167457659, -24.7442813358 +-25.5678347988, -24.9548183805 +-26.2465245804, -27.2305193544 +-27.9835454695, -27.1692484289 +-28.3903532434, -27.8213503725 +-28.3400577976, -29.3135473352 +-29.4761290608, -28.5574098299 +-29.6461310311, -29.5404569648 +-29.2777492301, -29.0648306613 +-29.2245451006, -29.7001437097 +-28.5771024844, -27.9649594933 +-27.8788455841, -27.6414702425 +-26.7770743241, -27.6979786838 +-29.3811831744, -29.0925772145 +-28.9060211218, -31.1561015383 +-28.0954226301, 
-28.5765410394 +-26.841042985, -27.4637768194 +-27.4970454542, -25.170442093 +-27.6219578552, -26.8830693978 +-28.5028285063, -28.899553903 +-28.9444221866, -28.955019796 +-28.5945811877, -25.1930306214 +-28.0645762886, -29.0048189769 +-29.2694375546, -28.6852751351 +-28.4725617574, -30.0814775143 +-28.1896347315, -29.9117622782 +-28.6803498264, -26.6559922175 +-27.8272377841, -26.0854433255 +-27.5531270078, -27.2916114581 +-28.3933323265, -29.3686502349 +-28.6761104906, -28.7931835838 +-28.5644833979, -29.8528160868 +-28.5914190664, -29.1071504235 +-29.1194521883, -29.6410189086 +-29.7733389199, -30.8875864944 +-30.1929613963, -30.7949759923 +-29.0754440487, -29.2003553493 +-30.1535434149, -31.6850269587 +-28.926679289, -29.0640604804 +-27.5537072757, -28.5559527167 +-27.1642822467, -26.1915996491 +-25.7717005527, -25.3841012191 +-26.8806644126, -26.2664867871 +-27.5788152447, -27.544656683 +-26.6795896796, -26.7665024784 +-27.8954693414, -26.2792813085 +-28.3862463654, -29.0503667823 +-27.8325111926, -28.0887610381 +-28.2162825528, -27.7037211175 +-27.5759980334, -27.5674001586 +-27.5963108883, -28.4858654796 +-27.6806186588, -28.5175805504 +-27.4264760641, -27.5557197497 +-27.8533510567, -27.7975810237 +-28.6610481656, -27.1910641749 +-30.3028641175, -30.3458477708 +-29.0807960863, -29.4372632273 +-30.4038464456, -32.42014751 +-29.4696142854, -29.5124983279 +-30.6955187035, -31.4443252999 +-31.0301102053, -29.5167335852 +-29.4349888618, -28.9406743638 +-28.7063818334, -27.9781402069 +-28.808576396, -30.0586458227 +-30.0811080604, -28.9126917175 +-28.2343237056, -26.8991548979 +-28.6279813437, -29.0396281724 +-29.2538691255, -29.5057669179 +-30.3642721421, -30.1102203023 +-30.9549181577, -31.124548036 +-31.9057107575, -32.4965182495 +-31.4498211797, -32.6238766097 +-30.5495543457, -30.5972062623 +-31.1281116132, -33.3422275229 +-30.0992379254, -31.3686079921 +-29.5025016404, -28.4322467055 +-29.8245769898, -29.8097619741 +-29.3067478683, -28.7238642651 
+-27.1694667384, -27.5995735429 +-25.8329261718, -25.7408732555 +-26.0276883855, -26.0770351201 +-26.9406004993, -27.1825939088 +-26.7743543113, -26.5666432799 +-27.8744247961, -28.8638414234 +-27.6630916554, -27.6189380144 +-29.7000259084, -28.0893413727 +-28.37539817, -28.5748712177 +-25.7365628938, -23.788588806 +-25.6346367152, -25.9963380661 +-26.5165192643, -29.386554196 +-26.8501942791, -27.3048776481 +-26.9976619112, -26.4401737012 +-27.7278769551, -28.3458030691 +-26.884174628, -26.6737479531 +-27.8384280309, -29.1714983155 +-28.109138205, -27.6354772732 +-25.9158047352, -26.4391144369 +-26.8036667494, -29.0356951236 +-25.407441713, -26.828763888 +-26.5433980582, -24.9932054087 +-27.7973918816, -27.8551876296 +-28.4577417359, -26.9970384834 +-29.728432159, -29.6138207174 +-29.203671105, -31.3386915407 +-28.213054479, -28.3867449355 +-27.8165981375, -26.9830706444 +-29.3344904327, -30.1899244569 +-28.1305534862, -26.8253615576 +-26.0905465846, -24.1397639191 +-26.9652411025, -26.6665361436 +-27.6595046377, -29.8607321894 +-27.7630733636, -26.9486052728 +-27.1475604888, -25.9716064436 +-26.5945680645, -27.070631818 +-27.6541337096, -28.0835629198 +-28.0666060441, -28.0678209774 +-28.6223834359, -28.4708110181 +-28.0087778844, -27.2368197794 +-27.0243969511, -26.8110022384 +-27.0736925293, -27.1971634197 +-25.984575064, -25.9942576644 +-25.6213017137, -26.6619070654 +-25.6671751974, -26.5709333309 +-26.5931345936, -26.1612330163 +-27.2701425142, -27.3185999438 +-28.0044601569, -27.6322463041 +-28.9850648485, -30.1626686867 +-28.5478951801, -29.7365570567 +-27.4032971146, -27.3008368132 +-26.8264671697, -26.9605797913 +-26.0364546726, -25.9071647029 +-24.9441580929, -25.6903052645 +-25.1446038853, -26.2547647056 +-26.5096664032, -25.8741371091 +-25.4955121402, -25.0352090808 +-23.9674484218, -23.0128640303 +-23.4831234069, -25.7140626129 +-23.0820783863, -23.7214718186 +-23.6288357805, -23.0958747742 +-24.525473866, -22.6015239583 +-23.8997870625, 
-23.280662377 +-23.5425419709, -23.7146507638 +-23.2264212596, -24.2184850738 +-23.2403193631, -23.2219310816 +-25.2395812762, -24.709715186 +-25.4318798465, -24.3778701334 +-24.4792600791, -25.9333098083 +-24.0889521241, -23.9644393612 +-24.5946995546, -25.0032653577 +-24.4228352134, -25.6226369989 +-25.250134446, -25.551608645 +-26.9503296954, -27.8701908004 +-28.6246891615, -28.1421936005 +-29.4569695802, -27.8607424088 +-30.2675838872, -29.8030765345 +-30.6791019198, -31.8474776272 +-31.3260230276, -32.5273134893 +-31.5489599271, -31.8425575525 +-32.2258437855, -32.2763683494 +-31.4269418816, -31.1674808786 +-32.2525804309, -32.4514063305 +-32.8709590329, -34.0024450031 +-31.7042237698, -32.004009662 +-32.0417080487, -32.375350767 +-33.7327222106, -30.9047163511 +-33.3311334979, -32.8906956456 +-33.3756419167, -34.0048407755 +-33.0868799744, -33.7209653233 +-35.1994234219, -33.9784147409 +-34.7642312495, -34.4171546542 +-35.5197235314, -34.9790707796 +-34.2151548166, -34.2400948744 +-33.0926883021, -35.1320330684 +-33.633773712, -32.7970000537 +-31.2682448033, -30.6292746603 +-31.0180250389, -30.3401132077 +-32.084570411, -33.5238311834 +-34.6673258102, -35.0937365659 +-33.5012579861, -33.340650492 +-34.6180080299, -36.0884485165 +-35.232426059, -34.6986115794 +-34.8924939148, -36.6700460457 +-34.8015683141, -35.909094015 +-35.5319332443, -35.5284850552 +-34.7886911121, -34.9053897374 +-34.6307228998, -34.845631327 +-33.0432938766, -32.6489745301 +-31.8237675451, -31.0858242087 +-32.692865082, -34.5799896469 +-33.0707898099, -33.184497 +-33.7197161189, -33.518085965 +-34.4973983184, -35.6687689824 +-34.9038206831, -36.2425984096 +-34.6834056296, -35.028800414 +-35.7572601182, -34.4613260214 +-36.4904011535, -37.0215512617 +-35.724763458, -34.7277162645 +-36.4773505669, -35.5102092337 +-37.5366710993, -37.003637724 +-39.8233595679, -39.5335629103 +-37.8602887647, -38.0875546401 +-37.2116764502, -36.1222446101 +-39.2261152025, -39.0109762372 +-40.338815215, 
-40.460308523 +-39.4688678968, -39.6488203898 +-40.7187471212, -41.7877086305 +-40.3498826646, -40.6970963772 +-39.706716719, -40.8080416914 +-39.8821984518, -39.5134615805 +-40.1355461533, -39.9324909783 +-39.0188114496, -40.1861574857 +-40.0357893528, -40.6334911463 +-39.1608752029, -39.0843793693 +-39.3342184444, -39.350719713 +-41.1744196464, -40.7189024988 +-41.418598466, -39.5466014058 +-39.9227366368, -39.6176780129 +-39.7145903496, -42.2054891719 +-40.0423145046, -39.8260555687 +-40.8021042038, -40.3250264334 +-40.7463477928, -42.411737623 +-40.241495777, -38.9772150581 +-39.8792685529, -40.1735049111 +-39.3546669122, -40.4736640123 +-39.9634912599, -39.1577241314 +-40.0019120821, -40.2391703124 +-41.1835631962, -41.6743794197 +-42.8477666916, -41.3144877174 +-42.8686958128, -43.7700743052 +-43.1529412449, -44.0563760647 +-44.7100218542, -44.828001253 +-45.7475947292, -47.0247575135 +-45.8700849646, -45.1880391877 +-46.7419685835, -46.4851733017 +-47.2151221027, -48.3638272282 +-46.5693869237, -45.9765751346 +-47.2169996427, -46.1887852873 +-48.7350800511, -48.0848163535 +-49.6663216711, -49.484483849 +-50.3157885886, -50.4178948668 +-49.5835478266, -48.8257321326 +-49.9189778527, -50.3653920606 +-49.2709053294, -49.5922778894 +-50.2034979012, -48.931265447 +-51.873663908, -51.8216371332 +-51.2586060044, -50.9734396181 +-52.6417032681, -53.0198864479 +-53.0495701655, -52.9849768815 +-51.7512264905, -52.2169215956 +-51.3992457957, -51.1873584381 +-48.8866771104, -50.9721397562 +-49.218059746, -49.9439380443 +-49.7580310208, -49.1182295322 +-49.4920851265, -48.2504326444 +-51.0019553612, -51.3281315438 +-49.9135759025, -49.0982112238 +-50.9725301087, -51.7547601608 +-50.7217572031, -52.4112276102 +-50.5194981139, -48.6245545952 +-49.385169457, -47.7557851584 +-50.7426582726, -51.0166502326 +-52.8433813643, -53.114183952 +-51.1527885406, -51.186656181 +-50.3241251673, -51.0180504006 +-48.5753299536, -47.962045343 +-48.207679923, -49.6448510359 +-49.1952657092, 
-50.275612414 +-48.9777524804, -48.5822605069 +-48.9324688648, -50.3417104246 +-49.6609423277, -51.6043624534 +-50.2626362408, -50.737545874 +-51.4311725609, -52.3724640355 +-52.8607103509, -53.1442489082 +-54.1779891036, -51.815497295 +-54.219049686, -54.3679404949 +-54.0650255842, -55.0474990054 +-53.6139337953, -54.1046288854 +-52.7713447095, -52.111538926 +-51.9747134137, -51.6985508288 +-53.7878322801, -53.3991865059 +-54.5587760487, -54.4257835011 +-53.6656146566, -52.1550585068 +-52.5667643043, -52.7983659955 +-53.5661608509, -53.8509876401 +-54.9371277794, -53.6966099828 +-54.3548058119, -54.6149710442 +-52.6531922092, -52.152868455 +-53.4044409113, -54.3836762094 +-55.4207430952, -55.2513531376 +-54.572260519, -55.0144657533 +-54.5998563225, -55.6401968275 +-54.2880402388, -54.5710403643 +-53.2404566001, -51.8503109176 +-54.1707442252, -54.8126613671 +-53.610565492, -52.825864764 +-53.3920800597, -53.7765145283 +-52.1625852769, -54.258492125 +-52.9429105507, -54.7388696799 +-52.2333718708, -51.8914609009 +-52.0840542658, -50.5778465696 +-51.304456867, -53.2397318512 +-51.4643161275, -52.0254090912 +-51.531743689, -51.7879215129 +-49.5126781952, -49.2716512842 +-49.6017643936, -47.528329564 +-47.3547720687, -47.1927853676 +-46.9468957892, -45.5096770119 +-46.0758007778, -46.4608802748 +-45.223937188, -44.6021570868 +-43.3054195558, -41.787110564 +-42.957858299, -40.9587112232 +-41.6109608724, -42.4086138754 +-39.4116045338, -39.09711946 +-41.831892172, -42.1068728743 +-41.3683996212, -39.0205923131 +-42.3861375331, -42.6965088439 +-42.6189425569, -41.4944245319 +-43.1006315465, -43.7112897132 +-43.9067488501, -44.4887062135 +-43.7114613514, -43.5845911821 +-44.5412903456, -45.3155911146 +-46.576528901, -46.0035876452 +-48.8506207882, -49.1834378554 +-47.657853524, -48.5379138688 +-48.7764780791, -49.5646650762 +-48.0797345209, -47.9775177988 +-48.1845304319, -47.9443699591 +-47.7390173834, -47.3005093641 +-46.8115242581, -46.5277020283 +-47.2726005293, 
-48.5238092692 +-48.2657758523, -48.2851160532 +-47.7104315093, -47.4969546198 +-49.9398509007, -50.2167969607 +-48.4528502997, -48.463990757 +-49.6398611243, -49.2780137235 +-48.0076768779, -47.4965093847 +-47.4553287896, -49.6855525676 +-48.0504670285, -46.505361751 +-46.9427489541, -47.8926679184 +-45.9625056675, -46.5903837292 +-46.262015636, -47.0956836747 +-45.3584506436, -44.4813252659 +-45.16300305, -45.8239578303 +-45.1552924844, -44.3776361388 +-46.267331778, -44.8311593774 +-46.2375590233, -46.4814983695 +-45.8737533246, -47.3342284513 +-45.9357554245, -45.6825629973 +-46.5663695938, -47.5700997421 +-46.8845558339, -47.5001354825 +-47.3659053489, -47.7904764322 +-47.5111368411, -46.9367875192 +-47.4935459974, -47.5892807589 +-48.252890548, -47.2035610018 +-48.2622386143, -48.7436761592 +-47.9811485009, -47.6117463921 +-48.6887593962, -49.5391794799 +-49.3289892754, -50.8319143546 +-50.2017126615, -50.4070080842 +-50.3476758829, -50.3676823485 +-49.3773682591, -49.7165363768 +-50.2899512234, -50.3649639086 +-51.4749650344, -52.2926140239 +-52.3537989376, -53.0480033676 +-51.8301096869, -51.9156215216 +-50.9988377036, -50.6988016317 +-49.8237442664, -50.0374828236 +-50.1187756697, -49.8274883739 +-53.2679656864, -52.6532479545 +-53.1408928348, -51.9907147106 +-51.9312233116, -52.2472835012 +-53.0521270719, -51.5958665745 +-51.8108226296, -51.8476365612 +-50.8069068842, -51.6954338887 +-50.5420845371, -49.9485951437 +-51.8329174347, -52.0035365548 +-50.3023721644, -51.0807516291 +-48.841264703, -48.9287445976 +-49.6885871322, -50.347174943 +-49.705146179, -49.6960113375 +-48.5314094897, -50.9751234783 +-49.2015782511, -48.2600018347 +-47.5698208592, -46.0439618662 +-48.1585537725, -49.7888608764 +-47.6548040904, -47.3748040093 +-50.0274925351, -48.9028683395 +-48.9625968441, -47.6413350883 +-49.8379193172, -50.5979851811 +-51.0704651754, -53.0822280628 +-50.9261708075, -51.0067205884 +-49.950390372, -51.4738277079 +-49.3314110559, -47.6192998846 
+-48.0300347478, -48.3970805467 +-46.7221904414, -44.8862217086 +-45.1540460544, -44.5661098443 +-45.5667176925, -46.4544737983 diff --git a/examples/graph_size/dune b/examples/graph_size/dune new file mode 100644 index 00000000..e2475cb4 --- /dev/null +++ b/examples/graph_size/dune @@ -0,0 +1,7 @@ +(env + (dev + (flags (:standard -w -9-27-33 -warn-error -A)))) + +(executables + (names main) + (libraries probzelus)) diff --git a/examples/graph_size/gen_data.zls b/examples/graph_size/gen_data.zls new file mode 100644 index 00000000..0872dc9a --- /dev/null +++ b/examples/graph_size/gen_data.zls @@ -0,0 +1,26 @@ +(* + * Copyright 2018-2020 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *) + +open Probzelus +open Distribution + +let random_init = Random.self_init () + +let node main () = + let init x0 = Distribution.draw (gaussian (0., 2500.)) in + let rec x = x0 -> Distribution.draw (gaussian (pre x, 1.)) in + let y = Distribution.draw (gaussian (x, 1.)) in + print_string ((string_of_float x) ^ ", " ^ (string_of_float y) ^ "\n") diff --git a/examples/graph_size/kalman_copy/Makefile b/examples/graph_size/kalman_copy/Makefile new file mode 100644 index 00000000..6bc29aff --- /dev/null +++ b/examples/graph_size/kalman_copy/Makefile @@ -0,0 +1,29 @@ +EX=kalman +ALGO=copy +NAME=kalman_copy +INFERLIB=../../../inference +ZELUC=zeluc -copy -I $(INFERLIB) +ZELUC += -I ../kalmangslib + +# Regenerate the OCaml output whenever the Zelus source file changes. +$(NAME).ml : $(NAME).zls + $(ZELUC) -noreduce $(NAME).zls + +%.zci : %.zli + $(ZELUC) $< + +build: $(NAME).ml run.ml + dune build run.exe + +exec: build + dune exec ./run.exe < ../data + +clean: + -rm -f *.zc* + -rm -f $(NAME).ml + +cleanall: clean + dune clean + -rm -f *~ + +all: build diff --git a/examples/graph_size/kalman_copy/dune b/examples/graph_size/kalman_copy/dune new file mode 100644 index 00000000..1adb8250 --- /dev/null +++ b/examples/graph_size/kalman_copy/dune @@ -0,0 +1,7 @@ +(env + (dev + (flags (:standard -w -9-27-33 -warn-error -A)))) + +(executables + (names run) + (libraries probzelus kalmangslib)) diff --git a/examples/graph_size/kalman_copy/kalman_copy.zls b/examples/graph_size/kalman_copy/kalman_copy.zls new file mode 100644 index 00000000..77323b9a --- /dev/null +++ b/examples/graph_size/kalman_copy/kalman_copy.zls @@ -0,0 +1,31 @@ +(* + * Copyright 2018-2020 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *) + +open Probzelus +open Distribution +open Infer_ds_streaming_copy +open Kalmangslib +(* 1-D random-walk model: xt is sampled from gaussian (const 0., 2500.) at the first instant and from gaussian (pre xt, 1.) afterwards; yobs is observed under gaussian (xt, 1.). Returns xt. *) +let proba kalman1d yobs = xt where + rec xt = sample (gaussian ((const 0., 2500.) -> (pre xt, 1.))) + and () = observe (gaussian (xt, 1.), yobs) +(* Apply infer to kalman1d with the given particles argument over the observed stream. *) +let node main_no_metric particles observed = + infer particles kalman1d observed +(* Same inference as main_no_metric, additionally returning Metrics.mse computed from (true_x, d). *) +let node main particles (true_x, observed) = (d, mse) where + rec d = main_no_metric particles observed + and mse = Metrics.mse (true_x, d) diff --git a/examples/graph_size/kalman_copy/run.ml b/examples/graph_size/kalman_copy/run.ml new file mode 100644 index 00000000..2ec1eaf1 --- /dev/null +++ b/examples/graph_size/kalman_copy/run.ml @@ -0,0 +1,42 @@ +(* + * Copyright 2018-2020 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *) + +let name = "Kalman-1D" +let algo = "SDS" +type input = float * float +type output = float Probzelus.Distribution.t * float +let read_input () = Scanf.scanf ("%f, %f\n") (fun t o -> (t, o)) +let main = Kalman_copy.main +let string_of_output (out, _) = string_of_float (Probzelus.Distribution.mean_float out) + +let num_particles = 5 + + +let rec run_helper step state = + try + let s = read_input () in + let out = step state s in + print_string ((string_of_output out) ^ "\n"); + run_helper step state + with End_of_file -> [] + +let run _ = + let Ztypes.Cnode {alloc; reset; step; copy = _} = main num_particles in + let init_state = alloc () in + reset init_state; + run_helper step init_state;; + +run () diff --git a/examples/graph_size/kalman_copy_instrumented/Makefile b/examples/graph_size/kalman_copy_instrumented/Makefile new file mode 100644 index 00000000..e5737a4a --- /dev/null +++ b/examples/graph_size/kalman_copy_instrumented/Makefile @@ -0,0 +1,31 @@ +EX=kalman +ALGO=copy +NAME=kalman_copy_instrumented +INFERLIB=../../../inference +ZELUC=zeluc -copy -I $(INFERLIB) +ZELUC += -I ../kalmangslib +PROBZELUC = probzeluc +PROBZELUC += -I ../kalmangslib + +# Regenerate the OCaml output whenever the Zelus source file changes. +$(NAME).ml : $(NAME).zls + $(PROBZELUC) -noreduce -nopt -nosimplify -inline 0 $(NAME).zls + +%.zci : %.zli + $(PROBZELUC) $< + +build: $(NAME).ml run.ml + dune build run.exe + +exec: build + dune exec ./run.exe < ../data + +clean: + -rm -f *.zc* + -rm -f $(NAME).ml + +cleanall: clean + dune clean + -rm -f *~ + +all: build diff --git a/examples/graph_size/kalman_copy_instrumented/dune b/examples/graph_size/kalman_copy_instrumented/dune new file mode 100644 index 00000000..216ad1ee --- /dev/null +++ b/examples/graph_size/kalman_copy_instrumented/dune @@ -0,0 +1,8 @@ +(env + (dev + (flags (:standard -w -9-27-33 -warn-error -A)))) + +(executables + (names run) + (libraries probzelus kalmangslib) + (modes native)) diff --git a/examples/graph_size/kalman_copy_instrumented/kalman_copy_instrumented.zls 
b/examples/graph_size/kalman_copy_instrumented/kalman_copy_instrumented.zls new file mode 100644 index 00000000..4b61671f --- /dev/null +++ b/examples/graph_size/kalman_copy_instrumented/kalman_copy_instrumented.zls @@ -0,0 +1,37 @@ +(* + * Copyright 2018-2020 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *) + +open Probzelus +open Distribution +open Infer_ds_streaming_copy_instrumented +open Kalmangslib +(* 1-D random-walk model, declared atomic: xt is sampled from gaussian (const 0., 2500.) at the first instant and from gaussian (pre xt, 1.) afterwards; yobs is observed under gaussian (xt, 1.). Returns xt. *) +let atomic proba do_kalman1d yobs = xt where + rec xt = sample (gaussian ((const 0., 2500.) -> (pre xt, 1.))) + and () = observe (gaussian (xt, 1.), yobs) +(* Instrumented wrapper around do_kalman1d: null1 and null2 chain Utils.to_unit, Utils.gc_full_major and print_ins through data dependencies, presumably so that a full major GC runs before print_ins at every step (TODO confirm the resulting scheduling). *) +let proba kalman1d yobs = xt where + rec xt = do_kalman1d yobs + and null1 = Utils.to_unit xt + and null2 = Utils.gc_full_major null1 + and () = print_ins null2 +(* Apply infer to kalman1d with the given particles argument over the observed stream. *) +let node main_no_metric particles observed = + infer particles kalman1d observed +(* Same inference as main_no_metric, additionally returning Metrics.mse computed from (true_x, d). *) +let node main particles (true_x, observed) = (d, mse) where + rec d = main_no_metric particles observed + and mse = Metrics.mse (true_x, d) diff --git a/examples/graph_size/kalman_copy_instrumented/run.ml b/examples/graph_size/kalman_copy_instrumented/run.ml new file mode 100644 index 00000000..4b9fa2c9 --- /dev/null +++ b/examples/graph_size/kalman_copy_instrumented/run.ml @@ -0,0 +1,41 @@ +(* + * Copyright 2018-2020 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *) + +let name = "Kalman-1D" +let algo = "SDS" +type input = float * float +type output = float Probzelus.Distribution.t * float +let read_input () = Scanf.scanf ("%f, %f\n") (fun t o -> (t, o)) +let main = Kalman_copy_instrumented.main +let string_of_output (out, _) = string_of_float (Probzelus.Distribution.mean_float out) + +let num_particles = 2 + +let rec run_helper step state = + try + let s = read_input () in + let out = step state s in + print_string ("Output:" ^ (string_of_output out) ^ "\n"); + run_helper step state + with End_of_file -> [] + +let run _ = + let Ztypes.Cnode {alloc; reset; step; copy = _} = main num_particles in + let init_state = alloc () in + reset init_state; + run_helper step init_state;; + +run () diff --git a/examples/graph_size/kalmangslib/Makefile b/examples/graph_size/kalmangslib/Makefile new file mode 100644 index 00000000..2f66d94e --- /dev/null +++ b/examples/graph_size/kalmangslib/Makefile @@ -0,0 +1,31 @@ +INFERLIB=../../../inference +OWLLIB=../../../owl +ZELUC=zeluc -copy -I $(INFERLIB) -I $(OWLLIB) +PROBZELUC=probzeluc + + +ZLI=$(wildcard *.zli) +ZCI=$(ZLI:zli=zci) + +all: $(ZCI) byte opt + +# These targets name actions, not files; the special target is spelled in upper case. +.PHONY: all byte opt clean cleanall + +byte: metrics.ml + dune build kalmangslib.cma + +opt: metrics.ml + dune build kalmangslib.cmxa + +%.zci: %.zli + $(PROBZELUC) $< + +metrics.ml : metrics.zls + $(PROBZELUC) -noreduce metrics.zls + +clean: + dune clean + -rm -f *.zci metrics.ml +cleanall: clean + rm -f *~ diff --git a/examples/graph_size/kalmangslib/dune b/examples/graph_size/kalmangslib/dune new file mode 100644 index 
00000000..e22a32ad --- /dev/null +++ b/examples/graph_size/kalmangslib/dune @@ -0,0 +1,3 @@ +(library + (name kalmangslib) + (libraries probzelus)) diff --git a/examples/graph_size/kalmangslib/kalmangslib.zli b/examples/graph_size/kalmangslib/kalmangslib.zli new file mode 100644 index 00000000..d13568c9 --- /dev/null +++ b/examples/graph_size/kalmangslib/kalmangslib.zli @@ -0,0 +1 @@ +(* Empty file *) diff --git a/examples/graph_size/kalmangslib/metrics.zls b/examples/graph_size/kalmangslib/metrics.zls new file mode 100644 index 00000000..8e6398ce --- /dev/null +++ b/examples/graph_size/kalmangslib/metrics.zls @@ -0,0 +1,25 @@ +(* + * Copyright 2018-2020 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *) + +open Probzelus +open Distribution +(* Running mean squared error: at each step the squared difference between mean_float d and true_x is added to a running total, which is then divided by the step counter t (1., 2., 3., ...). *) +let node mse (true_x, d) = mse where + rec t = 1. fby (t +. 1.) + and estimated_x = mean_float d + and error = (estimated_x -. true_x) ** 2. + and total_error = error -> (pre total_error) +. error + and mse = total_error /. 
t diff --git a/examples/graph_size/kalmangslib/utils.ml b/examples/graph_size/kalmangslib/utils.ml new file mode 100644 index 00000000..a47fceb1 --- /dev/null +++ b/examples/graph_size/kalmangslib/utils.ml @@ -0,0 +1,3 @@ +let to_unit _ = () +let gc_full_major _ = + Gc.full_major (); diff --git a/examples/graph_size/kalmangslib/utils.zli b/examples/graph_size/kalmangslib/utils.zli new file mode 100644 index 00000000..96a5c1a4 --- /dev/null +++ b/examples/graph_size/kalmangslib/utils.zli @@ -0,0 +1,2 @@ +val to_unit : 'a -> unit +val gc_full_major : unit -> unit diff --git a/probzelus/inference/ds_streaming_low_level_instrumented.ml b/probzelus/inference/ds_streaming_low_level_instrumented.ml new file mode 100644 index 00000000..a4b859a0 --- /dev/null +++ b/probzelus/inference/ds_streaming_low_level_instrumented.ml @@ -0,0 +1,160 @@ +(* + * Copyright 2018-2020 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *) +(* Instrumentation cell: weak pointers to the tracked graph nodes (each paired with an int tag) and a counter that the registered finalisers decrement. *) +type instrumentation = instrumentation_impl ref +and instrumentation_impl ={ + weaks : (Obj.t Weak.t * int) list; + finaliser_count : int ref +} +(* Inference state: the wrapped Ds_streaming_low_level state together with its instrumentation cell. *) +type pstate = { + ds_state : Ds_streaming_low_level.pstate; + nodes : instrumentation +} +(* Pair an underlying inference state with an instrumentation cell. *) +let mk_pstate pf_state ins_state = { + ds_state = pf_state; + nodes = ins_state +} + +(* Print one line on stdout reporting how many tracked nodes are still alive, + measured in two independent ways: the number of weak cells in weaks that + are still full, and the current value of finaliser_count. Fullness is + probed with Weak.check rather than Weak.get, so taking the measurement + never hands back a strong reference to a tracked node. + The printed format is unchanged. *) +let print_ins_helper ins = + let { weaks; finaliser_count } = !ins in + (* The int tag paired with each weak cell is irrelevant for counting. *) + let is_alive (w, _tag) = Weak.check w 0 in + let nodes_alive_weak = + List.fold_left + (fun alive entry -> if is_alive entry then succ alive else alive) + 0 weaks + in + let nodes_alive_finaliser = !finaliser_count in + print_string + ("Nodes alive weak: " ^ (string_of_int nodes_alive_weak) + ^ "; Nodes alive finaliser: " ^ (string_of_int nodes_alive_finaliser) + ^ "\n") +(* Print the instrumentation summary of the cell stored in pstate. *) +let print_ins pstate = print_ins_helper pstate.nodes + +(* Cell construction, reset and copy. NOTE(review): copy_ins assigns the record itself, so src and dst share one finaliser_count ref afterwards; confirm that this sharing is intended. *) +let empty_ins _ = ref { weaks = []; finaliser_count = ref 0 } +let clear ins = ins := { weaks = []; finaliser_count = ref 0 } +let copy_ins src dst = dst := !src +(* Re-exported unchanged from Ds_streaming_low_level. *) +let get_distr_kind = Ds_streaming_low_level.get_distr_kind +let get_distr = Ds_streaming_low_level.get_distr +(* Forward to the Infer_pf factor function, passing the wrapped ds_state. *) +let factor' (pstate, f0) = Infer_pf.factor' (pstate.ds_state, f0) +(* Re-exported unchanged from Ds_streaming_low_level. *) +let value = Ds_streaming_low_level.value +(* Register every value stored in tbl with ins: one weak pointer tagged -1 and one counting finaliser per value. *) +let add_all_nodes : type a. 
instrumentation -> (a, Obj.t) Hashtbl.t -> unit = + fun ins tbl -> + (*(print_string "---ADDING ALL NODES---\n"); + (print_string ("Initial finaliser: " ^ (string_of_int (!((!ins).finaliser_count))) ^ "\n"));*) + Hashtbl.iter (fun _ o -> + let w = Weak.create 1 in + Weak.set w 0 (Some o); + ins := { + weaks = (w, -1) :: (!ins).weaks; + finaliser_count = (!ins).finaliser_count + }; + let finaliser_count = (!ins).finaliser_count in + finaliser_count := !finaliser_count + 1; + Gc.finalise (fun _ -> + (*(print_string "---add_all_nodes FINALISER---\n");*) + finaliser_count := !finaliser_count - 1 + ) o + ) tbl + (*(print_string ("End finaliser: " ^ (string_of_int (!((!ins).finaliser_count))) ^ "\n")); + (print_string "---END ADD ALL NODES---\n")*) +(* Create a node with Ds_streaming_low_level.assume_constant and track it in ps.nodes: the shared counter is incremented now and decremented by a finaliser attached to the node, and a weak pointer tagged with ds_graph_node_id is pushed onto weaks. Returns the new node. *) +let assume_constant : type a p. + pstate -> a Types.mdistr -> (p, a) Types.ds_graph_node = + fun ps d -> + let ret = Ds_streaming_low_level.assume_constant d in + (*(print_string ("---ASSUME CONSTANT---: " ^ (string_of_int ret.ds_graph_node_id) ^ "\n"));*) + let w = Weak.create 1 in + Weak.set w 0 (Some (Obj.repr ret)); + (*Gc.finalise (fun _ -> ()) ret; (* Add a finaliser for the node to ensure + the node was heap-allocated. + Otherwise, the value will be copied into + the weak array and the weak pointer will + not properly indicate memory usage. *)*) + (*Gc.finalise (fun _ -> (print_string "Finalizer called!\n")) ret;*) + + let finaliser_count = (!(ps.nodes)).finaliser_count in + finaliser_count := !finaliser_count + 1; + (*let node_num = ret.ds_graph_node_id in*) + + Gc.finalise (fun _ -> + (*(print_string ("---assume_constant FINALISER for node " ^ (string_of_int node_num) ^ "---\n"));*) + finaliser_count := !finaliser_count - 1 + ) ret; + + ps.nodes := { + weaks = (w, ret.ds_graph_node_id) :: (!(ps.nodes)).weaks; + finaliser_count = !(ps.nodes).finaliser_count + }; + + ret +(* Create a child of p with Ds_streaming_low_level.assume_conditional and track it in ps.nodes with the same weak-pointer and finaliser bookkeeping as assume_constant. Returns the new node. *) +let assume_conditional : type a b c. 
+ pstate -> (a, b) Types.ds_graph_node -> (b, c) Types.cdistr -> (b, c) Types.ds_graph_node = + fun ps p cdistr -> + let ret = Ds_streaming_low_level.assume_conditional p cdistr in + (*(print_string ("---ASSUME CONDITIONAL---: " ^ (string_of_int ret.ds_graph_node_id) ^ "\n"));*) + let w = Weak.create 1 in + Weak.set w 0 (Some (Obj.repr ret)); + (*Gc.finalise (fun _ -> ()) ret; (* Add a finaliser for the node to ensure + the node was heap-allocated. + Otherwise, the value will be copied into + the weak array and the weak pointer will + not properly indicate memory usage. *)*) + (*Gc.finalise (fun _ -> (print_string "Finalizer called!\n")) ret;*) + + let finaliser_count = (!(ps.nodes)).finaliser_count in + finaliser_count := !finaliser_count + 1; + (*let node_num = ret.ds_graph_node_id in*) + + Gc.finalise (fun _ -> + (*(print_string ("---assume_conditional FINALISER for node " ^ (string_of_int node_num) ^ "---\n"));*) + finaliser_count := !finaliser_count - 1 + ) ret; + + ps.nodes := { + weaks = (w, ret.ds_graph_node_id) :: (!(ps.nodes)).weaks; + finaliser_count = (!(ps.nodes)).finaliser_count + }; + ret +(* Graft node n, then observe the value x on it using the wrapped ds_state; nothing is added to the instrumentation here. *) +(* TODO(eatkinson): use Ds_streaming_low_level.observe_with_graft once it is merged *) +let observe_with_graft : type a b. + pstate -> b -> (a, b) Types.ds_graph_node -> unit = + fun prob x n -> + Ds_streaming_low_level.graft n; + Ds_streaming_low_level.observe prob.ds_state x n +(* Thin wrapper: forwards to Ds_streaming_low_level.observe_conditional with the wrapped ds_state; nothing is added to the instrumentation here. *) +let observe_conditional : type a b c. 
+ pstate -> (a, b) Types.ds_graph_node -> (b, c) Types.cdistr -> c -> unit = + fun prob p cdistr obs -> + (*(print_string ("---OBSERVE CONDITIONAL---: " ^ (string_of_int p.ds_graph_node_id) ^ "\n"));*) + Ds_streaming_low_level.observe_conditional prob.ds_state p cdistr obs + diff --git a/probzelus/inference/dune b/probzelus/inference/dune index ec549ef9..0b553fc8 100644 --- a/probzelus/inference/dune +++ b/probzelus/inference/dune @@ -4,8 +4,8 @@ (action (bash "zeluc -I `zeluc -where`-owl %{zli}"))) (rule - (deps (:zli infer_ds_naive.zli infer_ds_streaming.zli infer_ds_streaming_copy.zli infer_importance.zli infer_pf.zli) distribution.zci) - (targets infer_ds_naive.zci infer_ds_streaming.zci infer_ds_streaming_copy.zci infer_importance.zci infer_pf.zci) + (deps (:zli infer_ds_naive.zli infer_ds_streaming.zli infer_ds_streaming_copy.zli infer_ds_streaming_copy_instrumented.zli infer_importance.zli infer_pf.zli) distribution.zci) + (targets infer_ds_naive.zci infer_ds_streaming.zci infer_ds_streaming_copy.zci infer_ds_streaming_copy_instrumented.zci infer_importance.zci infer_pf.zci) (action (bash "zeluc -I `zeluc -where`-owl %{zli}"))) (rule @@ -21,4 +21,4 @@ (install (package probzelus) (section share) - (files distribution.zci infer_ds_naive.zci infer_ds_streaming.zci infer_ds_streaming_copy.zci infer_importance.zci infer_pf.zci probzelus.zci)) + (files distribution.zci infer_ds_naive.zci infer_ds_streaming.zci infer_ds_streaming_copy.zci infer_ds_streaming_copy_instrumented.zci infer_importance.zci infer_pf.zci probzelus.zci)) diff --git a/probzelus/inference/infer_ds_streaming_copy_instrumented.ml b/probzelus/inference/infer_ds_streaming_copy_instrumented.ml new file mode 100644 index 00000000..511f34eb --- /dev/null +++ b/probzelus/inference/infer_ds_streaming_copy_instrumented.ml @@ -0,0 +1,690 @@ +(* + * Copyright 2018-2020 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *) + +(** Inference with delayed sampling *) +open Ztypes +open Types + +(* type 'a random_var = RV : ('b, 'a) Ds_streaming_low_level_instrumented.ds_node -> 'a random_var *) +type 'a random_var = { rv_id : int; } +type ('a, 'b) ds_node = ('a, 'b) ds_graph_node + +(* module Gnodes = struct *) +(* module E = Ephemeron.K1 *) + +(* module M = Hashtbl.Make(struct *) +(* type t = int *) +(* let equal (x: int) (y: int) = (x = y) *) +(* let hash (x: int) = Hashtbl.hash x *) +(* (\* let compare (x:int) (y:int) = compare x y *\) *) +(* end) *) + +(* (\* type t = (Obj.t random_var, (Obj.t, Obj.t) ds_node) E.t M.t ref *\) *) +(* type ephemeron = (Obj.t random_var, unit) E.t *) +(* type t = *) +(* { live_nodes : (ephemeron * (Obj.t, Obj.t) ds_node) M.t; *) +(* mutable ephemeron_pool: ephemeron list; } *) + +(* let create _ = *) +(* { live_nodes = M.create 11; *) +(* ephemeron_pool = []; } *) + +(* let new_ephemeron g = *) +(* begin match g.ephemeron_pool with *) +(* | [] -> E.create () *) +(* | e::p -> g.ephemeron_pool <- p; e *) +(* end *) + +(* let add: type a p. *) +(* t -> a random_var -> (p, a) ds_node -> unit = *) +(* fun g x y -> *) +(* let e = new_ephemeron g in *) +(* E.set_key e (Obj.magic x: Obj.t random_var); *) +(* (\* E.set_data e (Obj.magic y: (Obj.t, Obj.t) ds_node); *\) *) +(* let n = (Obj.magic y: (Obj.t, Obj.t) ds_node) in *) +(* M.add g.live_nodes x.rv_id (e, n) *) + +(* let find_opt: type a p. 
*) +(* t -> a random_var -> (p, a) ds_node option = *) +(* fun g x -> *) +(* let k = (Obj.magic x: Obj.t random_var).rv_id in *) +(* begin match M.find_opt g.live_nodes k with *) +(* | None -> None *) +(* | Some (e, n) -> Some (Obj.magic n: (p, a) ds_node) *) +(* end *) + +(* let clear: t -> unit = *) +(* fun g -> *) +(* g.ephemeron_pool <- *) +(* M.fold *) +(* (fun _ (e, _) acc -> E.unset_key e; e::acc) *) +(* g.live_nodes *) +(* g.ephemeron_pool; *) +(* M.clear g.live_nodes *) + +(* let clean: t -> unit = *) +(* fun g -> *) +(* M.filter_map_inplace *) +(* (fun _ (e, n) -> *) +(* let b = E.check_key e in *) +(* if not b then begin *) +(* g.ephemeron_pool <- e::g.ephemeron_pool; *) +(* None *) +(* end *) +(* else Some (e, n)) *) +(* g.live_nodes *) + +(* let copy: t -> t -> unit = *) +(* fun src dst -> *) +(* let tbl = Hashtbl.create 11 in *) +(* (\* clean src; *\) *) +(* clear dst; *) +(* M.iter (fun k (e, n) -> *) +(* let e' = new_ephemeron dst in *) +(* begin match E.get_key e with *) +(* | Some x -> E.set_key e' x; *) +(* | _ -> () *) +(* end; *) +(* let n' = Ds_streaming_low_level_instrumented.copy_node tbl n in *) +(* M.add dst.live_nodes k (e', n')) *) +(* src.live_nodes *) +(* end *) + + +module Gnodes = struct + module E = Ephemeron.K1 + + module M = Map.Make(struct + type t = int + let compare (x:int) (y:int) = compare x y + end) + + (* type t = (Obj.t random_var, (Obj.t, Obj.t) ds_node) E.t M.t ref *) + type ephemeron = (Obj.t random_var, unit) E.t + type t = + { mutable live_nodes : (ephemeron * (Obj.t, Obj.t) ds_node) M.t; + mutable ephemeron_pool: ephemeron list; } + + let size g = M.cardinal g.live_nodes + + let create _ = + { live_nodes = M.empty; + ephemeron_pool = []; } + + let new_ephemeron g = + begin match g.ephemeron_pool with + | [] -> E.create () + | e::p -> g.ephemeron_pool <- p; e + end + + let add: type a p. 
+ t -> a random_var -> (p, a) ds_node -> unit = + fun g x y -> + let e = new_ephemeron g in + E.set_key e (Obj.magic x: Obj.t random_var); + (* E.set_data e (Obj.magic y: (Obj.t, Obj.t) ds_node); *) + let n = (Obj.magic y: (Obj.t, Obj.t) ds_node) in + g.live_nodes <- M.add x.rv_id (e, n) g.live_nodes + + let find_opt: type a p. + t -> a random_var -> (p, a) ds_node option = + fun g x -> + let k = (Obj.magic x: Obj.t random_var).rv_id in + begin match M.find_opt k g.live_nodes with + | None -> None + | Some (_e, n) -> Some (Obj.magic n: (p, a) ds_node) + end + + let clear: t -> unit = + fun g -> + g.ephemeron_pool <- + M.fold + (fun _ (e, _) acc -> E.unset_key e; e::acc) + g.live_nodes + g.ephemeron_pool; + g.live_nodes <- M.empty + + let clean: t -> unit = + fun g -> + g.live_nodes <- M.filter + (fun _ (e, _) -> + let b = E.check_key e in + if not b then g.ephemeron_pool <- e::g.ephemeron_pool; + b) + g.live_nodes + + let copy: Ds_streaming_low_level_instrumented.instrumentation -> t -> t -> unit = + fun ctx src dst -> + let tbl = Hashtbl.create 41 in + (*(print_string "-----START COPY-------\n");*) + (*(print_string ("Size pre-clean: " ^ (string_of_int (size src)) ^ "\n"));*) + Gc.full_major (); + clean src; + (*(print_string ("Size post-clean: " ^ (string_of_int (size src)) ^ "\n"));*) + clear dst; + dst.live_nodes <- M.map (fun (e, n) -> + let e' = new_ephemeron dst in + begin match E.get_key e with + | Some x -> E.set_key e' x; + | _ -> () + end; + let n' = Distribution.DS_graph.copy_node tbl n in + (e', n')) + src.live_nodes; + (*(print_string ("Size dst: " ^ (string_of_int (size dst)) ^ "\n"));*) + Ds_streaming_low_level_instrumented.add_all_nodes ctx tbl; + (*(Ds_streaming_low_level_instrumented.print_ins_helper ctx);*) + (*(print_string "-----END COPY-------\n");*) +end + +(* module Gnodes = struct *) +(* module E = Ephemeron.K1 *) + +(* module H = Ephemeron.K1.Make(struct *) +(* (\* module Gnodes = Hashtbl.Make(struct *\) *) +(* type t = Obj.t random_var 
(\* random_var *\) *) +(* let equal x y = x.rv_id = y.rv_id *) +(* let hash x = Hashtbl.hash x.rv_id *) +(* end) *) + +(* type t = (Obj.t, Obj.t) ds_node H.t *) + +(* let create n = H.create n *) + +(* let add: type a p. *) +(* t -> a random_var -> (p, a) ds_node -> unit = *) +(* fun g x y -> *) +(* let n = (Obj.magic y: (Obj.t, Obj.t) ds_node) in *) +(* H.add g (Obj.magic x: Obj.t random_var) n *) + +(* let find_opt: type a p. *) +(* t -> a random_var -> (p, a) ds_node option = *) +(* fun g x -> *) +(* let k = (Obj.magic x: Obj.t random_var) in *) +(* begin match H.find_opt g k with *) +(* | None -> None *) +(* | Some n -> Some (Obj.magic n: (p, a) ds_node) *) +(* end *) + +(* let clear: t -> unit = H.clear *) + +(* let clean: t -> unit = H.clean *) + +(* let copy: t -> t -> unit = *) +(* fun src dst -> *) +(* let tbl = Hashtbl.create 11 in *) +(* (\* clean src; *\) *) +(* clear dst; *) +(* H.iter *) +(* (fun k n -> H.add dst k (Ds_streaming_low_level_instrumented.copy_node tbl n)) *) +(* src *) +(* end *) + + +type pstate = + { pf_state: Ds_streaming_low_level_instrumented.pstate; + ds_graph: Gnodes.t; } + +let rv_node : type a p. + pstate -> a random_var -> (p, a) ds_graph_node = + fun prob x -> + let g = prob.ds_graph in + begin match Gnodes.find_opt g x with + | None -> + Format.eprintf "Failed %d@." x.rv_id; + assert false + | Some o -> o + end + +let add_random_var: type a p. 
+ pstate -> a random_var -> (p, a) ds_graph_node -> unit = + fun prob rv n -> + let g = prob.ds_graph in + Gnodes.add g rv n + +let rv_kind prob rv = + let n = rv_node prob rv in + Ds_streaming_low_level_instrumented.get_distr_kind n + +let rv_distr prob rv = + let n = rv_node prob rv in + Ds_streaming_low_level_instrumented.get_distr n + +let factor' (prob, s) = Ds_streaming_low_level_instrumented.factor' (prob.pf_state, s) +let factor = + let alloc () = () in + let reset _state = () in + let copy _src _dst = () in + let step _state input = + factor' input + in + Cnode { alloc; reset; copy; step; } + +type _ expr_tree = + | Econst : 'a -> 'a expr_tree + | Ervar : 'a random_var -> 'a expr_tree + | Eadd : float expr * float expr -> float expr_tree + | Emult : float expr * float expr -> float expr_tree + | Eapp : ('a -> 'b) expr * 'a expr -> 'b expr_tree + | Epair : 'a expr * 'b expr -> ('a * 'b) expr_tree + | Earray : 'a expr array -> 'a array expr_tree +and 'a expr = { mutable value : 'a expr_tree; } + +let const : 'a. 'a -> 'a expr = + begin fun v -> + { value = Econst v; } + end + +let add : float expr * float expr -> float expr = + begin fun (e1, e2) -> + begin match e1.value, e2.value with + | Econst x, Econst y -> { value = Econst (x +. y); } + | _ -> { value = Eadd (e1, e2); } + end + end + +let ( +~ ) x y = add (x, y) + +let mult : float expr * float expr -> float expr = + begin fun (e1, e2) -> + begin match e1.value, e2.value with + | Econst x, Econst y -> { value = Econst (x *. y); } + | Ervar _, Econst _ -> { value = Emult(e2, e1); } + | _ -> { value = Emult(e1, e2); } + end + end + +let ( *~ ) x y = mult (x, y) + +let app : type t1 t2. 
(t1 -> t2) expr * t1 expr -> t2 expr = + begin fun (e1, e2) -> + begin match e1.value, e2.value with + | Econst f, Econst x -> { value = Econst (f x); } + | _ -> { value = Eapp(e1, e2); } + end + end + +let ( @@~ ) f e = app (f, e) + +let pair (e1, e2) = + { value = Epair (e1, e2) } + +let array a = + { value = Earray a } + +let rec eval' : type t. + pstate -> t expr -> t = + begin fun prob e -> + begin match e.value with + | Econst v -> v + | Ervar x -> + let n = rv_node prob x in + let v = Ds_streaming_low_level_instrumented.value n in + e.value <- Econst v; + v + | Eadd (e1, e2) -> + let v = eval' prob e1 +. eval' prob e2 in + e.value <- Econst v; + v + | Emult (e1, e2) -> + let v = eval' prob e1 *. eval' prob e2 in + e.value <- Econst v; + v + | Eapp (e1, e2) -> + let v = (eval' prob e1) (eval' prob e2) in + e.value <- Econst v; + v + | Epair (e1, e2) -> + let v = (eval' prob e1, eval' prob e2) in + e.value <- Econst v; + v + | Earray a -> + Array.map (eval' prob) a + end + end + +let eval = + let alloc () = () in + let reset _state = () in + let copy _src _dst = () in + let step _state (prob, input) = + eval' prob input + in + Cnode { alloc; reset; copy; step; } + + +(* let rec fval : type t. t expr -> t = + begin fun e -> + begin match e.value with + | Econst v -> v + | Ervar (RV x) -> Ds_streaming_low_level_instrumented.fvalue x + | Eadd (e1, e2) -> fval e1 +. fval e2 + | Emult (e1, e2) -> fval e1 *. 
fval e2 + | Eapp (e1, e2) -> (fval e1) (fval e2) + | Epair (e1, e2) -> (fval e1, fval e2) + end + end *) + +let rec string_of_expr e = + begin match e.value with + | Econst v -> string_of_float v + | Ervar x -> "RV_" ^ string_of_int x.rv_id + | Eadd (e1, e2) -> "(" ^ string_of_expr e1 ^ " + " ^ string_of_expr e2 ^ ")" + | Emult (e1, e2) -> "(" ^ string_of_expr e1 ^ " * " ^ string_of_expr e2 ^ ")" + | Eapp (_, _) -> "App" + end + +(* High level delayed sampling distribution (pdistribution in Haskell) *) +type 'a ds_distribution = + { isample : (pstate -> 'a expr); + iobserve : (pstate * 'a -> unit); } + +let sample = + let alloc () = () in + let reset _state = () in + let copy _src _dst = () in + let step _state (prob, ds_distr) = + ds_distr.isample prob + in + Cnode { alloc; reset; copy; step; } + +let observe = + let alloc () = () in + let reset _state = () in + let copy _src _dst = () in + let step _state (prob, (ds_distr, o)) = + ds_distr.iobserve(prob, o) + in + Cnode { alloc; reset; copy; step; } + +let print_ins = + let alloc () = () in + let reset _state = () in + let copy _src _dst = () in + let step _state (prob, _) = + Ds_streaming_low_level_instrumented.print_ins prob.pf_state + in + Cnode { alloc; reset; copy; step; } + +let of_distribution d = + { isample = (fun _prob -> const (Distribution.draw d)); + iobserve = (fun (prob, obs) -> factor' (prob, Distribution.score(d, obs))); } + +let ds_distr_with_fallback d is iobs = + let dsd = + let state = ref None in + (fun prob -> + begin match !state with + | None -> + let dsd = of_distribution (d prob) in + state := Some dsd; + dsd + | Some dsd -> dsd + end) + in + let is' prob = + begin match is prob with + | None -> (dsd prob).isample prob + | Some x -> x + end + in + let iobs' (prob, obs) = + begin match iobs (prob, obs) with + | None -> (dsd prob).iobserve (prob, obs) + | Some () -> () + end + in + { isample = is'; iobserve = iobs'; } + +(* An affine_expr is either a constant or an affine transformation 
of a
+ * random variable *)
+type affine_expr =
+  (* Interpretation (m, x, b) such that the output is m * x + b *)
+  | AErvar of float * float random_var * float
+  | AEconst of float
+
+(* [affine_of_expr e] tries to read [e] as either a constant or an affine
+   function m * x + b of a single random variable x.
+   Returns [None] when [e] involves two random variables or an application.
+   Constant/constant sums and products are folded: although [add] and [mult]
+   fold constants when an expression is built, [eval'] later rewrites
+   sub-expressions to [Econst] in place, so an [Eadd] or [Emult] node can end
+   up with two constant children. *)
+let rec affine_of_expr : float expr -> affine_expr option =
+  begin fun expr ->
+    begin match expr.value with
+    | Econst v -> Some (AEconst v)
+    | Ervar var -> Some (AErvar (1., var, 0.))
+    | Eadd (e1, e2) ->
+        begin match (affine_of_expr e1, affine_of_expr e2) with
+        | (Some (AEconst v1), Some (AEconst v2)) -> Some (AEconst (v1 +. v2))
+        | (Some (AErvar (m, x, b)), Some (AEconst v))
+        | (Some (AEconst v), Some (AErvar (m, x, b))) -> Some (AErvar (m, x, b +. v))
+        | _ -> None
+        end
+    | Emult (e1, e2) ->
+        begin match (affine_of_expr e1, affine_of_expr e2) with
+        | (Some (AEconst v1), Some (AEconst v2)) -> Some (AEconst (v1 *. v2))
+        | (Some (AErvar (m, x, b)), Some (AEconst v))
+        | (Some (AEconst v), Some (AErvar (m, x, b))) -> Some (AErvar (m *. v, x, b *. v))
+        | _ -> None
+        end
+    | Eapp (_, _) -> None
+    end
+  end
+
+(* Create a node with distribution [d] through the instrumented low-level
+   API and register it in the graph of [prob] under a random variable whose
+   id is the node id. *)
+let assume_constant : type a.
+  pstate -> a mdistr -> a random_var =
+  fun prob d ->
+    let n = Ds_streaming_low_level_instrumented.assume_constant prob.pf_state d in
+    let rv = { rv_id = n.ds_graph_node_id } in
+    add_random_var prob rv n;
+    rv
+
+(* Create a node conditioned on the node of [p] with conditional
+   distribution [d], and register it like [assume_constant] does.
+   The lookup of [p] goes through [rv_node], which fails with [assert false]
+   when [p] is not in the graph. *)
+let assume_conditional : type b c.
+  pstate -> b random_var -> (b, c) cdistr -> c random_var =
+  fun prob p d ->
+    let par = rv_node prob p in
+    let n = Ds_streaming_low_level_instrumented.assume_conditional prob.pf_state par d in
+    let rv = { rv_id = n.ds_graph_node_id } in
+    add_random_var prob rv n;
+    rv
+
+let observe_conditional : type b c.
+ pstate -> b random_var -> (b, c) cdistr -> c -> unit = + fun prob p cdistr obs -> + let tmp_rv = assume_conditional prob p cdistr in + Ds_streaming_low_level_instrumented.observe_with_graft prob.pf_state obs (rv_node prob tmp_rv) + + +(** Gaussian distribution (gaussianPD in Haskell) *) +let gaussian (mu, std) = + let d prob = Distribution.gaussian(eval' prob mu, std) in + let is prob = + begin match affine_of_expr mu with + | Some (AEconst v) -> + let rv = assume_constant prob (Dist_gaussian(v, std)) in + Some { value = (Ervar rv) } + | Some (AErvar (m, x, b)) -> + begin match rv_kind prob x with + | KGaussian -> + let rv = + assume_conditional prob x (AffineMeanGaussian(m, b, std)) + in + Some { value = (Ervar rv) } + | _ -> None + end + | None -> None + end + in + let iobs (prob, obs) = + begin match affine_of_expr mu with + | Some (AEconst _) -> + None + | Some (AErvar (m, x, b)) -> + begin match rv_kind prob x with + | KGaussian -> + Some (observe_conditional prob x (AffineMeanGaussian(m, b, std)) obs) + | _ -> None + end + | None -> None + end + in + ds_distr_with_fallback d is iobs + +(** Beta distribution (betaPD in Haskell) *) +let beta (a, b) = + let d _prob = Distribution.beta(a, b) in + let is prob = + Some { value = Ervar (assume_constant prob (Dist_beta (a, b))) } + in + let iobs (_prob, _obs) = None in + ds_distr_with_fallback d is iobs + +(** Bernoulli distribution (bernoulliPD in Haskell) *) +let bernoulli p = + let d prob = Distribution.bernoulli (eval' prob p) in + let with_beta_prior prob f = + begin match p.value with + | Ervar par -> + begin match rv_kind prob par with + | KBeta -> Some (f par) + | _ -> None + end + | _ -> None + end + in + let is prob = + with_beta_prior prob + (fun par -> + { value = Ervar (assume_conditional prob par CBernoulli) }) + in + let iobs (prob, obs) = + with_beta_prior prob + (fun par -> observe_conditional prob par CBernoulli obs) + in + ds_distr_with_fallback d is iobs + +(** Inference *) + +let rec 
distribution_of_expr : type a. pstate -> a expr -> a Distribution.t = + fun prob expr -> + begin match expr.value with + | Econst c -> Dist_support [c, 1.] + | Ervar x -> rv_distr prob x + | Eadd (e1, e2) -> + Dist_add (distribution_of_expr prob e1, distribution_of_expr prob e2) + | Emult (e1, e2) -> + Dist_mult (distribution_of_expr prob e1, distribution_of_expr prob e2) + | Eapp (e1, e2) -> + Dist_app (distribution_of_expr prob e1, distribution_of_expr prob e2) + | Epair (e1, e2) -> + Dist_pair (distribution_of_expr prob e1, distribution_of_expr prob e2) + | Earray a -> + Dist_array (Array.map (distribution_of_expr prob) a) + end + +type 'a node_state = + { node_state: 'a; + node_graph: Gnodes.t; + ins : Ds_streaming_low_level_instrumented.instrumentation } + +let infer n (Cnode { alloc; reset; copy; step; }) = + let alloc () = + { node_state = alloc (); + node_graph = Gnodes.create 11; + ins = Ds_streaming_low_level_instrumented.empty_ins () } + in + let reset state = + reset state.node_state; + Gnodes.clear state.node_graph; + Ds_streaming_low_level_instrumented.clear state.ins + in + let step state (pf_prob, x) = + let prob = + { pf_state = Ds_streaming_low_level_instrumented.mk_pstate pf_prob state.ins; + ds_graph = state.node_graph; } + in + let d = distribution_of_expr prob (step state.node_state (prob, x)) in + Gnodes.clean state.node_graph; + d + in + let copy src dst = + copy src.node_state dst.node_state; + Ds_streaming_low_level_instrumented.clear dst.ins; + Gnodes.copy dst.ins src.node_graph dst.node_graph; + (*Ds_streaming_low_level_instrumented.copy_ins src.ins dst.ins*) + in + let Cnode {alloc = infer_alloc; reset = infer_reset; + copy = infer_copy; step = infer_step;} = + Infer_pf.infer n (Cnode { alloc; reset; copy = copy; step; }) + in + let infer_step state i = + Distribution.to_mixture (infer_step state i) + in + Cnode {alloc = infer_alloc; reset = infer_reset; + copy = infer_copy; step = infer_step; } + + +(* +let infer_ess_resample n 
threshold (Cnode { alloc; reset; copy; step; }) = + let alloc () = + { node_state = alloc (); + node_graph = Gnodes.create 11; } + in + let reset state = + reset state.node_state; + Gnodes.clear state.node_graph + in + let step state (pf_prob, x) = + let prob = + { pf_state = pf_prob; + ds_graph = state.node_graph; } + in + let d = distribution_of_expr prob (step state.node_state (prob, x)) in + Gnodes.clean state.node_graph; + d + in + let copy src dst = + copy src.node_state dst.node_state; + Gnodes.copy src.node_graph dst.node_graph + in + let Cnode {alloc = infer_alloc; reset = infer_reset; + copy = infer_copy; step = infer_step;} = + Infer_pf.infer_ess_resample n threshold + (Cnode { alloc; reset; copy; step; }) + in + let infer_step state i = + Distribution.to_mixture (infer_step state i) + in + Cnode {alloc = infer_alloc; reset = infer_reset; + copy = infer_copy; step = infer_step;} + +let infer_bounded n (Cnode { alloc; reset; copy; step; }) = + let alloc () = + { node_state = alloc (); + node_graph = Gnodes.create 11; } + in + let reset state = + reset state.node_state; + Gnodes.clear state.node_graph + in + let step state (pf_prob, x) = + let prob = + { pf_state = pf_prob; + ds_graph = state.node_graph; } + in + let v = eval' prob (step state.node_state (prob, x)) in + Gnodes.clean state.node_graph; + v + in + let copy src dst = + copy src.node_state dst.node_state; + Gnodes.copy src.node_graph dst.node_graph + in + Infer_pf.infer n (Cnode { alloc; reset; copy; step; }) +*) diff --git a/probzelus/inference/infer_ds_streaming_copy_instrumented.zli b/probzelus/inference/infer_ds_streaming_copy_instrumented.zli new file mode 100644 index 00000000..5be846a6 --- /dev/null +++ b/probzelus/inference/infer_ds_streaming_copy_instrumented.zli @@ -0,0 +1,54 @@ +(* + * Copyright 2018-2020 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *)
+
+(** Inference with delayed sampling *)
+
+type 'a expr
+
+val const : 'a -> 'a expr
+val add : float expr * float expr -> float expr
+val ( +~ ) : float expr -> float expr -> float expr
+val mult : float expr * float expr -> float expr
+val ( *~ ) : float expr -> float expr -> float expr
+val app : ('a -> 'b) expr * 'a expr -> 'b expr
+val ( @@~ ) : ('a -> 'b) expr -> 'a expr -> 'b expr
+val pair : 'a expr * 'b expr -> ('a * 'b) expr
+val array : 'a expr array -> 'a array expr
+(* val lst : 'a expr list -> 'a list expr   -- disabled: no definition in the .ml *)
+(* val matrix : 'a expr array array -> 'a array array expr   -- disabled: no definition in the .ml *)
+
+val eval : 'a expr ~D~> 'a
+
+type 'a ds_distribution
+
+val of_distribution : 'a Distribution.t -> 'a ds_distribution
+val gaussian : float expr * float -> float ds_distribution
+val beta : float * float -> float ds_distribution
+val bernoulli : float expr -> bool ds_distribution
+
+val factor : float ~D~> unit
+val sample : 'a ds_distribution ~D~> 'a expr
+val observe : 'a ds_distribution * 'a ~D~> unit
+val print_ins : unit ~D~> unit
+
+val infer :
+    int -S-> ('a ~D~> 'b expr) -S-> 'a -D-> 'b Distribution.t
+
+(* Disabled: the matching definitions are commented out in
+   infer_ds_streaming_copy_instrumented.ml, so exporting them here would only fail later, at OCaml compile time.
+   val infer_ess_resample : int -S-> float -S-> ('a ~D~> 'b expr) -S-> 'a -D-> 'b Distribution.t
+   val infer_bounded : int -S-> ('a ~D~> 'b expr) -S-> 'a -D-> 'b Distribution.t
+ *)