From 539c7cf7831e1a1c2469abbed7eabb05f542a2a4 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Tue, 29 Mar 2022 15:21:38 -0400 Subject: [PATCH] Trying to change more |> into -> --- .../src/rescript/GenericDist/GenericDist.res | 29 ++++++++++++------- .../GenericDist_GenericOperation.res | 25 ++++++---------- .../squiggle-lang/src/rescript/utility/E.res | 14 +++++++++ 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res index 3a10558a..e956a303 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res @@ -41,7 +41,7 @@ let operationToFloat = (toPointSet: toPointSetFn, fnName, t) => { switch symbolicSolution { | Some(r) => Ok(r) - | None => toPointSet(t) |> E.R.fmap(PointSetDist.operate(fnName)) + | None => toPointSet(t)->E.R.fmap2(PointSetDist.operate(fnName)) } } @@ -53,6 +53,8 @@ let defaultSamplingInputs: SamplingInputs.samplingInputs = { kernelWidth: None, } +//Todo: If it's a pointSet, but the xyPointLenght is different from what it has, it should change. +// This is tricky because the case of discrete distributions. let toPointSet = (t, xyPointLength): result => { switch t { | #PointSet(pointSet) => Ok(pointSet) @@ -106,7 +108,10 @@ let truncate = Truncate.run /* Given two random variables A and B, this returns the distribution of a new variable that is the result of the operation on A and B. For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). - In general, this is implemented via convolution. */ + In general, this is implemented via convolution. + + TODO: It would be useful to be able to pass in a paramater to get this to run either with convolution or monte carlo. +*/ module AlgebraicCombination = { let tryAnalyticalSimplification = ( operation: GenericDist_Types.Operation.arithmeticOperation, @@ -174,9 +179,9 @@ module AlgebraicCombination = { | None => switch chooseConvolutionOrMonteCarlo(t1, t2) { | #CalculateWithMonteCarlo => - runMonteCarlo(toSampleSet, algebraicOp, t1, t2) |> E.R.fmap(r => #SampleSet(r)) + runMonteCarlo(toSampleSet, algebraicOp, t1, t2)->E.R.fmap2(r => #SampleSet(r)) | #CalculateWithConvolution => - runConvolution(toPointSet, algebraicOp, t1, t2) |> E.R.fmap(r => #PointSet(r)) + runConvolution(toPointSet, algebraicOp, t1, t2)->E.R.fmap2(r => #PointSet(r)) } } } @@ -190,10 +195,10 @@ let pointwiseCombination = (toPointSet: toPointSetFn, operation, t2: t, t1: t): error, > => { E.R.merge(toPointSet(t1), toPointSet(t2)) - |> E.R.fmap(((t1, t2)) => + ->E.R.fmap2(((t1, t2)) => PointSetDist.combinePointwise(GenericDist_Types.Operation.arithmeticToFn(operation), t1, t2) ) - |> E.R.fmap(r => #PointSet(r)) + ->E.R.fmap2(r => #PointSet(r)) } let pointwiseCombinationFloat = ( @@ -205,7 +210,7 @@ let pointwiseCombinationFloat = ( switch operation { | #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid) | (#Multiply | #Divide | #Exponentiate | #Log) as operation => - toPointSet(t) |> E.R.fmap(t => { + toPointSet(t)->E.R.fmap2(t => { //TODO: Move to PointSet codebase let fn = (secondary, main) => Operation.Scale.toFn(operation, main, secondary) let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(operation) @@ -217,9 +222,10 @@ let pointwiseCombinationFloat = ( t, ) }) - } |> E.R.fmap(r => #PointSet(r)) + }->E.R.fmap2(r => #PointSet(r)) } +//Note: The result should always cumulatively sum to 1. let mixture = ( scaleMultiply: scaleMultiplyFn, pointwiseAdd: pointwiseAddFn, @@ -228,9 +234,12 @@ let mixture = ( if E.A.length(values) == 0 { Error(GenericDist_Types.Other("mixture must have at least 1 element")) } else { + let totalWeight = values->E.A.fmap2(E.Tuple2.second)->E.A.Floats.sum let properlyWeightedValues = - values |> E.A.fmap(((dist, weight)) => scaleMultiply(dist, weight)) |> E.A.R.firstErrorOrOpen - properlyWeightedValues |> E.R.bind(_, values => { + values + ->E.A.fmap2(((dist, weight)) => scaleMultiply(dist, weight /. totalWeight)) + ->E.A.R.firstErrorOrOpen + properlyWeightedValues->E.R.bind(values => { values |> Js.Array.sliceFrom(1) |> E.A.fold_left( diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res index 2923445e..8dbdfe5a 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res @@ -16,7 +16,6 @@ type outputType = [ | #String(string) ] - module Output = { let toDist = (o: outputType) => switch o { @@ -83,36 +82,30 @@ let rec run = (extra, fnName: operation): outputType => { reCall( ~fnName=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r), (), - ) -> outputToDistResult + )->outputToDistResult let pointwiseAdd = (r1, r2) => reCall( ~fnName=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1), (), - ) -> outputToDistResult - + )->outputToDistResult let fromDistFn = (subFn: GenericDist_Types.Operation.fromDist, dist: genericDist) => switch subFn { | #toFloat(fnName) => - GenericDist.operationToFloat(toPointSet, fnName, dist) - |> E.R.fmap(r => #Float(r)) - |> fromResult - | #toString => dist -> GenericDist.toString -> (r => #String(r)) + GenericDist.operationToFloat(toPointSet, fnName, dist)->E.R.fmap2(r => #Float(r))->fromResult + | #toString => dist->GenericDist.toString->(r => #String(r)) | #toDist(#consoleLog) => { Js.log2("Console log requested: ", dist) #Dist(dist) } - | #toDist(#normalize) => dist -> GenericDist.normalize -> (r => #Dist(r)) + | #toDist(#normalize) => dist->GenericDist.normalize->(r => #Dist(r)) | #toDist(#truncate(left, right)) => dist |> GenericDist.truncate(toPointSet, left, right) |> E.R.fmap(r => #Dist(r)) |> fromResult | #toDist(#toPointSet) => - dist - -> GenericDist.toPointSet(xyPointLength) - |> E.R.fmap(r => #Dist(#PointSet(r))) - |> fromResult + dist->GenericDist.toPointSet(xyPointLength)->E.R.fmap2(r => #Dist(#PointSet(r)))->fromResult | #toDist(#toSampleSet(n)) => - dist -> GenericDist.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult + dist->GenericDist.sampleN(n)->E.R.fmap2(r => #Dist(#SampleSet(r)))->fromResult | #toDistCombination(#Algebraic, _, #Float(_)) => #Error(NotYetImplemented) | #toDistCombination(#Algebraic, operation, #Dist(dist2)) => dist @@ -135,7 +128,7 @@ let rec run = (extra, fnName: operation): outputType => { | #fromDist(subFn, dist) => fromDistFn(subFn, dist) | #fromFloat(subFn, float) => reCall(~fnName=#fromDist(subFn, GenericDist.fromFloat(float)), ()) | #mixture(dists) => - GenericDist.mixture(scaleMultiply, pointwiseAdd, dists) |> E.R.fmap(r => #Dist(r)) |> fromResult + GenericDist.mixture(scaleMultiply, pointwiseAdd, dists)->E.R.fmap2(r => #Dist(r))->fromResult } } @@ -154,5 +147,5 @@ let fmap = ( | (#fromDist(_), _) => Error(Other("Expected dist, got something else")) | (#fromFloat(_), _) => Error(Other("Expected float, got something else")) } - newFnCall |> E.R.fmap(r => run(extra, r)) |> fromResult + newFnCall->E.R.fmap2(r => run(extra, r))->fromResult } diff --git a/packages/squiggle-lang/src/rescript/utility/E.res b/packages/squiggle-lang/src/rescript/utility/E.res index d17850fd..69d8ecc1 100644 --- a/packages/squiggle-lang/src/rescript/utility/E.res +++ b/packages/squiggle-lang/src/rescript/utility/E.res @@ -33,6 +33,17 @@ module U = { let id = e => e } +module Tuple2 = { + let first = (v: ('a, 'b)) => { + let (a, _) = v + a + } + let second = (v: ('a, 'b)) => { + let (_, b) = v + b + } +} + module O = { let dimap = (sFn, rFn, e) => switch e { @@ -137,6 +148,7 @@ module R = { let result = Rationale.Result.result let id = e => e |> result(U.id, U.id) let fmap = Rationale.Result.fmap + let fmap2 = (a,b) => Rationale.Result.fmap(b,a) let bind = Rationale.Result.bind let toExn = Belt.Result.getExn let default = (default, res: Belt.Result.t<'a, 'b>) => @@ -233,6 +245,7 @@ module L = { /* A for Array */ module A = { let fmap = Array.map + let fmap2 = (a,b) => Array.map(b,a) let fmapi = Array.mapi let to_list = Array.to_list let of_list = Array.of_list @@ -405,6 +418,7 @@ module A = { : { let _ = Js.Array.push(element, continuous) } + () })