From c5afb2d867ab6c440c636e600f6665198defa173 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Sat, 26 Mar 2022 16:56:56 -0400 Subject: [PATCH] Fleshed out AlgebraicCombination --- .../AlgebraicShapeCombination.res | 2 + .../rescript/pointSetDist/PointSetDist.res | 3 +- packages/squiggle-lang/src/rescript/sci.res | 117 ++++++++++++++---- 3 files changed, 96 insertions(+), 26 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/pointSetDist/AlgebraicShapeCombination.res b/packages/squiggle-lang/src/rescript/pointSetDist/AlgebraicShapeCombination.res index e08d8887..3c298b18 100644 --- a/packages/squiggle-lang/src/rescript/pointSetDist/AlgebraicShapeCombination.res +++ b/packages/squiggle-lang/src/rescript/pointSetDist/AlgebraicShapeCombination.res @@ -115,6 +115,7 @@ let combineShapesContinuousContinuous = ( | #Multiply => (m1, m2) => m1 *. m2 | #Divide => (m1, mInv2) => m1 *. mInv2 | #Exponentiate => (m1, mInv2) => m1 ** mInv2 + | #Log => (m1, m2) => log(m1) /. log(m2) } // note: here, mInv2 = mean(1 / t2) ~= 1 / mean(t2) // TODO: I don't know what the variances are for exponentatiation @@ -232,6 +233,7 @@ let combineShapesContinuousDiscrete = ( } | #Multiply | #Exponentiate + | #Log | #Divide => for j in 0 to t2n - 1 { // creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes. diff --git a/packages/squiggle-lang/src/rescript/pointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/pointSetDist/PointSetDist.res index 4a9eda9e..59bead6b 100644 --- a/packages/squiggle-lang/src/rescript/pointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/pointSetDist/PointSetDist.res @@ -41,7 +41,8 @@ let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t = | (Continuous(m1), Discrete(m2)) | (Discrete(m2), Continuous(m1)) => Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toPointSetDist - | (Discrete(m1), Discrete(m2)) => Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist + | (Discrete(m1), Discrete(m2)) => + Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist | (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist } diff --git a/packages/squiggle-lang/src/rescript/sci.res b/packages/squiggle-lang/src/rescript/sci.res index 66a5605a..ddcbf182 100644 --- a/packages/squiggle-lang/src/rescript/sci.res +++ b/packages/squiggle-lang/src/rescript/sci.res @@ -55,12 +55,6 @@ module OperationType = { | #Sample(int) ] - type scale = [ - | #Multiply - | #Exponentiate - | #Log - ] - type t = [ | #toFloat(toFloat) | #toDist(toDist) @@ -73,6 +67,7 @@ type operation = OperationType.t module T = { type t = genericDist type toPointSetFn = genericDist => result + type toSampleSetFn = genericDist => result, error> let sampleN = (n, t: t) => { switch t { | #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) @@ -81,7 +76,7 @@ module T = { } } - let toFloat = (toPointSet: toPointSetFn, fnName, t: genericDist) => { + let toFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result => { switch t { | #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) => switch SymbolicDist.T.operate(fnName, r) { @@ -104,7 +99,7 @@ module T = { kernelWidth: None, } - let toPointSet = (xyPointLength, t: t) => { + let toPointSet = (xyPointLength, t: t): result => { switch t { | #PointSet(pointSet) => Ok(pointSet) | #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r)) @@ -122,16 +117,82 @@ module T = { } } - let algebraicCombination = (operation, sampleCount, dist1: t, dist2: t) => { - let dist1 = sampleN(sampleCount, dist1) - let dist2 = sampleN(sampleCount, dist2) - let samples = E.R.merge(dist1, dist2) |> E.R.fmap(((d1, d2)) => { - Belt.Array.zip(d1, d2) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b)) - }) - samples |> E.R.fmap(r => #SampleSet(r)) + module AlgebraicCombination = { + let tryAnalyticalSimplification = (operation: OperationType.combination, t1: t, t2: t): option< + result, + > => + switch (operation, t1, t2) { + | (operation, #Symbolic(d1), #Symbolic(d2)) => + switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) { + | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist)) + | #Error(er) => Some(Error(er)) + | #NoSolution => None + } + | _ => None + } + + let runConvolution = ( + toPointSet: toPointSetFn, + operation: OperationType.combination, + t1: t, + t2: t, + ) => + E.R.merge(toPointSet(t1), toPointSet(t2)) |> E.R.fmap(((a, b)) => + PointSetDist.combineAlgebraically(operation, a, b) + ) + + let runMonteCarlo = ( + toSampleSet: toSampleSetFn, + operation: OperationType.combination, + t1: t, + t2: t, + ) => { + E.R.merge(toSampleSet(t1), toSampleSet(t2)) |> E.R.fmap(((a, b)) => { + Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b)) + }) + } + + //I'm (Ozzie) really just guessing here, very little idea what's best + let expectedConvolutionCost: t => int = x => + switch x { + | #Symbolic(#Float(_)) => 1 + | #Symbolic(_) => 1000 + | #PointSet(Discrete(m)) => m.xyShape |> XYShape.T.length + | #PointSet(Mixed(_)) => 1000 + | #PointSet(Continuous(_)) => 1000 + | _ => 1000 + } + + let chooseConvolutionOrMonteCarlo = (t1: t, t2: t) => + expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000 + ? #CalculateWithMonteCarlo + : #CalculateWithConvolution + + let run = ( + toPointSet: toPointSetFn, + toSampleSet: toSampleSetFn, + algebraicOp, + t1: t, + t2: t, + ): result => { + switch tryAnalyticalSimplification(algebraicOp, t1, t2) { + | Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist)) + | Some(Error(e)) => Error(Other(e)) + | None => + switch chooseConvolutionOrMonteCarlo(t1, t2) { + | #CalculateWithMonteCarlo => + runMonteCarlo(toSampleSet, algebraicOp, t1, t2) |> E.R.fmap(r => #SampleSet(r)) + | #CalculateWithConvolution => + runConvolution(toPointSet, algebraicOp, t1, t2) |> E.R.fmap(r => #PointSet(r)) + } + } + } } - let pointwiseCombination = (toPointSet: toPointSetFn, operation, t1: t, t2: t) => { + let pointwiseCombination = (toPointSet: toPointSetFn, operation, t1: t, t2: t): result< + t, + error, + > => { E.R.merge(toPointSet(t1), toPointSet(t2)) |> E.R.fmap(((t1, t2)) => PointSetDist.combinePointwise(OperationType.combinationToFn(operation), t1, t2) @@ -144,11 +205,12 @@ module T = { operation: OperationType.combination, t: t, f: float, - ) => { + ): result => { switch operation { | #Add | #Subtract => Error(DistributionVerticalShiftIsInvalid) | (#Multiply | #Divide | #Exponentiate | #Log) as operation => toPointSet(t) |> E.R.fmap(t => { + //TODO: Move to PointSet codebase let fn = (secondary, main) => Operation.Scale.toFn(operation, main, secondary) let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(operation) let integralCacheFn = Operation.Scale.toIntegralCacheFn(operation) @@ -159,7 +221,7 @@ module T = { t, ) }) - } + } |> E.R.fmap(r => #PointSet(r)) } } @@ -188,10 +250,10 @@ module OmniRunner = { | Error(e) => #Error(e) } - let rec applyFnInternal = (wrapped: wrapped, fnName: operation): outputType => { + let rec run = (wrapped: wrapped, fnName: operation): outputType => { let (value, {sampleCount, xyPointLength} as extra) = wrapped let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => { - applyFnInternal((value, extra), fnName) + run((value, extra), fnName) } let toPointSet = r => { switch reCall(~value=r, ~fnName=#toDist(#toPointSet), ()) { @@ -200,8 +262,14 @@ module OmniRunner = { | _ => Error(Other("Impossible error")) } } - let toPointSetAndReCall = v => toPointSet(v) |> E.R.fmap(r => reCall(~value=#PointSet(r), ())) - let newVal: outputType = switch (fnName, value) { + let toSampleSet = r => { + switch reCall(~value=r, ~fnName=#toDist(#toSampleSet(sampleCount)), ()) { + | #Dist(#SampleSet(p)) => Ok(p) + | #Error(r) => Error(r) + | _ => Error(Other("Impossible error")) + } + } + switch (fnName, value) { // | (#toFloat(n), v) => toFloat(toPointSet, v, n) | (#toFloat(fnName), _) => T.toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult @@ -214,17 +282,16 @@ module OmniRunner = { value |> T.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult | (#toDistCombination(#Algebraic, _, #Float(_)), _) => #Error(NotYetImplemented) | (#toDistCombination(#Algebraic, operation, #Dist(p2)), p1) => - T.algebraicCombination(operation, sampleCount, p1, p2) + T.AlgebraicCombination.run(toPointSet, toSampleSet, operation, p1, p2) |> E.R.fmap(r => #Dist(r)) |> fromResult | (#toDistCombination(#Pointwise, operation, #Dist(p2)), p1) => T.pointwiseCombination(toPointSet, operation, p1, p2) |> E.R.fmap(r => #Dist(r)) |> fromResult | (#toDistCombination(#Pointwise, operation, #Float(f)), _) => T.pointwiseCombinationFloat(toPointSet, operation, value, f) - |> E.R.fmap(r => #Dist(#PointSet(r))) + |> E.R.fmap(r => #Dist(r)) |> fromResult } - newVal } }