From 2ec1bfd068acf8d7d29562cb6420890d219e5458 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Sat, 26 Mar 2022 22:06:19 -0400 Subject: [PATCH] Added Truncate to sci.res --- packages/squiggle-lang/src/rescript/sci.res | 113 ++++++++++++++------ 1 file changed, 83 insertions(+), 30 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/sci.res b/packages/squiggle-lang/src/rescript/sci.res index ddcbf182..bf6cc03e 100644 --- a/packages/squiggle-lang/src/rescript/sci.res +++ b/packages/squiggle-lang/src/rescript/sci.res @@ -1,3 +1,5 @@ +//TODO: multimodal, add interface, split up a little bit, test somehow, track performance, refactor sampleSet, refactor ASTEvaluator.res. + type error = | NeedsPointSetConversion | InputsNeedPointSetConversion @@ -18,7 +20,7 @@ module OperationType = { | #Pointwise ] - type combination = [ + type arithmeticOperation = [ | #Add | #Multiply | #Subtract @@ -27,8 +29,8 @@ module OperationType = { | #Log ] - let combinationToFn = (combination: combination) => - switch combination { + let arithmeticToFn = (arithmetic: arithmeticOperation) => + switch arithmetic { | #Add => \"+." | #Multiply => \"*." | #Subtract => \"-." @@ -49,6 +51,7 @@ module OperationType = { | #normalize | #toPointSet | #toSampleSet(int) + | #truncate(option, option) ] type toFloatArray = [ @@ -58,7 +61,7 @@ module OperationType = { type t = [ | #toFloat(toFloat) | #toDist(toDist) - | #toDistCombination(direction, combination, [#Dist(genericDist) | #Float(float)]) + | #toDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)]) ] } @@ -76,6 +79,14 @@ module T = { } } + let normalize = (t: t) => { + switch t { + | #PointSet(r) => #PointSet(PointSetDist.T.normalize(r)) + | #Symbolic(_) => t + | #SampleSet(_) => t + } + } + let toFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result => { switch t { | #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) => @@ -117,10 +128,46 @@ module T = { } } + module Truncate = { + let trySymbolicSimplification = (leftCutoff, rightCutoff, t): option => + switch (leftCutoff, rightCutoff, t) { + | (None, None, _) => None + | (lc, rc, #Symbolic(#Uniform(u))) if lc < rc => + Some(#Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u)))) + | _ => None + } + + let run = ( + toPointSet: toPointSetFn, + leftCutoff: option, + rightCutoff: option, + t: t, + ): result => { + let doesNotNeedCutoff = E.O.isNone(leftCutoff) && E.O.isNone(rightCutoff) + if doesNotNeedCutoff { + Ok(t) + } else { + switch trySymbolicSimplification(leftCutoff, rightCutoff, t) { + | Some(r) => Ok(r) + | None => + toPointSet(t) |> E.R.fmap(t => + #PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t)) + ) + } + } + } + } + + /* Given two random variables A and B, this returns the distribution + of a new variable that is the result of the operation on A and B. + For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). + In general, this is implemented via convolution. */ module AlgebraicCombination = { - let tryAnalyticalSimplification = (operation: OperationType.combination, t1: t, t2: t): option< - result, - > => + let tryAnalyticalSimplification = ( + operation: OperationType.arithmeticOperation, + t1: t, + t2: t, + ): option> => switch (operation, t1, t2) { | (operation, #Symbolic(d1), #Symbolic(d2)) => switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) { @@ -133,7 +180,7 @@ module T = { let runConvolution = ( toPointSet: toPointSetFn, - operation: OperationType.combination, + operation: OperationType.arithmeticOperation, t1: t, t2: t, ) => @@ -143,7 +190,7 @@ module T = { let runMonteCarlo = ( toSampleSet: toSampleSetFn, - operation: OperationType.combination, + operation: OperationType.arithmeticOperation, t1: t, t2: t, ) => { @@ -163,7 +210,7 @@ module T = { | _ => 1000 } - let chooseConvolutionOrMonteCarlo = (t1: t, t2: t) => + let chooseConvolutionOrMonteCarlo = (t2: t, t1: t) => expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000 ? #CalculateWithMonteCarlo : #CalculateWithConvolution @@ -189,22 +236,23 @@ module T = { } } - let pointwiseCombination = (toPointSet: toPointSetFn, operation, t1: t, t2: t): result< + //TODO: Add faster pointwiseCombine fn + let pointwiseCombination = (toPointSet: toPointSetFn, operation, t2: t, t1: t): result< t, error, > => { E.R.merge(toPointSet(t1), toPointSet(t2)) |> E.R.fmap(((t1, t2)) => - PointSetDist.combinePointwise(OperationType.combinationToFn(operation), t1, t2) + PointSetDist.combinePointwise(OperationType.arithmeticToFn(operation), t1, t2) ) |> E.R.fmap(r => #PointSet(r)) } let pointwiseCombinationFloat = ( toPointSet: toPointSetFn, - operation: OperationType.combination, - t: t, + operation: OperationType.arithmeticOperation, f: float, + t: t, ): result => { switch operation { | #Add | #Subtract => Error(DistributionVerticalShiftIsInvalid) @@ -259,36 +307,41 @@ module OmniRunner = { switch reCall(~value=r, ~fnName=#toDist(#toPointSet), ()) { | #Dist(#PointSet(p)) => Ok(p) | #Error(r) => Error(r) - | _ => Error(Other("Impossible error")) + | _ => Error(ImpossiblePath) } } let toSampleSet = r => { switch reCall(~value=r, ~fnName=#toDist(#toSampleSet(sampleCount)), ()) { | #Dist(#SampleSet(p)) => Ok(p) | #Error(r) => Error(r) - | _ => Error(Other("Impossible error")) + | _ => Error(ImpossiblePath) } } - switch (fnName, value) { + switch fnName { // | (#toFloat(n), v) => toFloat(toPointSet, v, n) - | (#toFloat(fnName), _) => + | #toFloat(fnName) => T.toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult - | (#toDist(#normalize), #PointSet(r)) => #Dist(#PointSet(PointSetDist.T.normalize(r))) - | (#toDist(#normalize), #Symbolic(_)) => #Dist(value) - | (#toDist(#normalize), #SampleSet(_)) => #Dist(value) - | (#toDist(#toPointSet), _) => + | #toDist(#normalize) => value |> T.normalize |> (r => #Dist(r)) + | #toDist(#truncate(left, right)) => + value |> T.Truncate.run(toPointSet, left, right) |> E.R.fmap(r => #Dist(r)) |> fromResult + | #toDist(#toPointSet) => value |> T.toPointSet(xyPointLength) |> E.R.fmap(r => #Dist(#PointSet(r))) |> fromResult - | (#toDist(#toSampleSet(n)), _) => + | #toDist(#toSampleSet(n)) => value |> T.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult - | (#toDistCombination(#Algebraic, _, #Float(_)), _) => #Error(NotYetImplemented) - | (#toDistCombination(#Algebraic, operation, #Dist(p2)), p1) => - T.AlgebraicCombination.run(toPointSet, toSampleSet, operation, p1, p2) + | #toDistCombination(#Algebraic, _, #Float(_)) => #Error(NotYetImplemented) + | #toDistCombination(#Algebraic, operation, #Dist(value2)) => + value + |> T.AlgebraicCombination.run(toPointSet, toSampleSet, operation, value2) |> E.R.fmap(r => #Dist(r)) |> fromResult - | (#toDistCombination(#Pointwise, operation, #Dist(p2)), p1) => - T.pointwiseCombination(toPointSet, operation, p1, p2) |> E.R.fmap(r => #Dist(r)) |> fromResult - | (#toDistCombination(#Pointwise, operation, #Float(f)), _) => - T.pointwiseCombinationFloat(toPointSet, operation, value, f) + | #toDistCombination(#Pointwise, operation, #Dist(value2)) => + value + |> T.pointwiseCombination(toPointSet, operation, value2) + |> E.R.fmap(r => #Dist(r)) + |> fromResult + | #toDistCombination(#Pointwise, operation, #Float(f)) => + value + |> T.pointwiseCombinationFloat(toPointSet, operation, f) |> E.R.fmap(r => #Dist(r)) |> fromResult }