diff --git a/src/distPlus/distribution/AlgebraicShapeCombination.re b/src/distPlus/distribution/AlgebraicShapeCombination.re index edcf00ad..fd6f7453 100644 --- a/src/distPlus/distribution/AlgebraicShapeCombination.re +++ b/src/distPlus/distribution/AlgebraicShapeCombination.re @@ -110,7 +110,11 @@ let toDiscretePointMassesFromTriangulars = }; let combineShapesContinuousContinuous = - (op: ExpressionTypes.algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape) + ( + op: ExpressionTypes.algebraicOperation, + s1: DistTypes.xyShape, + s2: DistTypes.xyShape, + ) : DistTypes.xyShape => { let t1n = s1 |> XYShape.T.length; let t2n = s2 |> XYShape.T.length; @@ -118,10 +122,11 @@ let combineShapesContinuousContinuous = // if we add the two distributions, we should probably use normal filters. // if we multiply the two distributions, we should probably use lognormal filters. let t1m = toDiscretePointMassesFromTriangulars(s1); - let t2m = switch (op) { + let t2m = + switch (op) { | `Divide => toDiscretePointMassesFromTriangulars(~inverse=true, s2) | _ => toDiscretePointMassesFromTriangulars(~inverse=false, s2) - }; + }; let combineMeansFn = switch (op) { @@ -185,27 +190,30 @@ let combineShapesContinuousContinuous = // we now want to create a set of target points. For now, let's just evenly distribute 200 points between // between the outputMinX and outputMaxX let nOut = 300; - let outputXs: array(float) = E.A.Floats.range(outputMinX^, outputMaxX^, nOut); + let outputXs: array(float) = + E.A.Floats.range(outputMinX^, outputMaxX^, nOut); let outputYs: array(float) = Belt.Array.make(nOut, 0.0); // now, for each of the outputYs, accumulate from a Gaussian kernel over each input point. - for (j in 0 to E.A.length(masses) - 1) { // go through all of the result points + for (j in 0 to E.A.length(masses) - 1) { + // go through all of the result points if (variances[j] > 0. && masses[j] > 0.) { - for (i in 0 to E.A.length(outputXs) - 1) { // go through all of the target points + for (i in 0 to E.A.length(outputXs) - 1) { + // go through all of the target points let dx = outputXs[i] -. means[j]; - let contribution = masses[j] *. exp(-. (dx ** 2.) /. (2. *. variances[j])) /. (sqrt(2. *. 3.14159276 *. variances[j])); + let contribution = + masses[j] + *. exp(-. (dx ** 2.) /. (2. *. variances[j])) + /. sqrt(2. *. 3.14159276 *. variances[j]); Belt.Array.set(outputYs, i, outputYs[i] +. contribution) |> ignore; - (); }; - (); - } |> ignore; - (); + }; }; {xs: outputXs, ys: outputYs}; }; -let toDiscretePointMassesFromDiscrete = (s: DistTypes.xyShape): pointMassesWithMoments => { - let n = s |> XYShape.T.length; +let toDiscretePointMassesFromDiscrete = + (s: DistTypes.xyShape): pointMassesWithMoments => { let {xs, ys}: XYShape.T.t = s; let n = E.A.length(xs); @@ -231,7 +239,7 @@ let combineShapesContinuousDiscrete = switch (op) { | `Add - | `Subtract => { + | `Subtract => for (j in 0 to t2n - 1) { // creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes. let dxyShape: array((float, float)) = @@ -240,16 +248,16 @@ let combineShapesContinuousDiscrete = Belt.Array.set( dxyShape, i, - (fn(continuousShape.xs[i], discreteShape.xs[j]), continuousShape.ys[i] *. discreteShape.ys[j]), + (fn(continuousShape.xs[i], discreteShape.xs[j]), + continuousShape.ys[i] *. discreteShape.ys[j]), ) |> ignore; (); }; Belt.Array.set(outXYShapes, j, dxyShape) |> ignore; (); } - } - | `Multiply - | `Divide => { + | `Multiply + | `Divide => for (j in 0 to t2n - 1) { // creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes. let dxyShape: array((float, float)) = @@ -265,7 +273,6 @@ let combineShapesContinuousDiscrete = Belt.Array.set(outXYShapes, j, dxyShape) |> ignore; (); } - } }; outXYShapes diff --git a/src/distPlus/distribution/Shape.re b/src/distPlus/distribution/Shape.re index 3e1dcfeb..01685c3c 100644 --- a/src/distPlus/distribution/Shape.re +++ b/src/distPlus/distribution/Shape.re @@ -27,24 +27,19 @@ let combineAlgebraically = (op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => { switch (t1, t2) { | (Continuous(m1), Continuous(m2)) => - DistTypes.Continuous( - Continuous.combineAlgebraically(op, m1, m2), - ) + Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toShape; | (Continuous(m1), Discrete(m2)) | (Discrete(m2), Continuous(m1)) => - DistTypes.Continuous( - Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2), - ) + Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toShape | (Discrete(m1), Discrete(m2)) => - DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2)) + Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toShape | (m1, m2) => - DistTypes.Mixed( - Mixed.combineAlgebraically( - op, - toMixed(m1), - toMixed(m2), - ), + Mixed.combineAlgebraically( + op, + toMixed(m1), + toMixed(m2), ) + |> Mixed.T.toShape }; }; diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 0a4b7070..fd6d06ee 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -20,61 +20,15 @@ module AlgebraicCombination = { | _ => Ok(`AlgebraicCombination((operation, t1, t2))) }; - let tryCombination = (n, algebraicOp, t1: node, t2: node) => { - let sampleN = - mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n)); - switch (sampleN(t1), sampleN(t2)) { - | (Some(a), Some(b)) => - Some( - Belt.Array.zip(a, b) - |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)), - ) - | _ => None - }; - }; - - let renderIfNotRendered = (params, t) => - !renderable(t) - ? switch (render(params, t)) { - | Ok(r) => Ok(r) - | Error(e) => Error(e) - } - : Ok(t); - - let combineAsShapes = - (evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => { - let i1 = renderIfNotRendered(evaluationParams, t1); - let i2 = renderIfNotRendered(evaluationParams, t2); - E.R.merge(i1, i2) - |> E.R.bind( - _, - ((a, b)) => { - let samples = - tryCombination( - evaluationParams.samplingInputs.sampleCount, - algebraicOp, - a, - b, - ); - let shape = - samples - |> E.O.fmap( - Samples.T.fromSamples( - ~samplingInputs={ - sampleCount: - Some(evaluationParams.samplingInputs.sampleCount), - outputXYPoints: - Some(evaluationParams.samplingInputs.outputXYPoints), - kernelWidth: evaluationParams.samplingInputs.kernelWidth, - }, - ), - ) - |> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) => - r.shape - ) - |> E.O.toResult("No response"); - shape |> E.R.fmap(r => `Normalize(`RenderedDist(r))); - }, + let combinationByRendering = + (evaluationParams, algebraicOp, t1: node, t2: node) + : result(node, string) => { + E.R.merge( + Render.ensureIsRenderedAndGetShape(evaluationParams, t1), + Render.ensureIsRenderedAndGetShape(evaluationParams, t2), + ) + |> E.R.fmap(((a, b)) => + `RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b)) ); }; @@ -92,7 +46,7 @@ module AlgebraicCombination = { _, fun | `SymbolicDist(d) as t => Ok(t) - | _ => combineAsShapes(evaluationParams, algebraicOp, t1, t2), + | _ => combinationByRendering(evaluationParams, algebraicOp, t1, t2), ); }; @@ -103,7 +57,7 @@ module VerticalScaling = { let fn = Operation.Scale.toFn(scaleOp); let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp); let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp); - let renderedShape = render(evaluationParams, t); + let renderedShape = Render.render(evaluationParams, t); switch (renderedShape, scaleBy) { | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) => @@ -125,13 +79,22 @@ module VerticalScaling = { module PointwiseCombination = { let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => { - switch (render(evaluationParams, t1), render(evaluationParams, t2)) { + switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) { | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => Ok( `RenderedDist( Shape.combinePointwise( ~integralSumCachesFn=(a, b) => Some(a +. b), - ~integralCachesFn=(a, b) => Some(Continuous.combinePointwise(~distributionType=`CDF, (+.), a, b)), + ~integralCachesFn= + (a, b) => + Some( + Continuous.combinePointwise( + ~distributionType=`CDF, + (+.), + a, + b, + ), + ), (+.), rs1, rs2, @@ -153,7 +116,12 @@ module PointwiseCombination = { }; let operationToLeaf = - (evaluationParams: evaluationParams, pointwiseOp: pointwiseOperation, t1: t, t2: t) => { + ( + evaluationParams: evaluationParams, + pointwiseOp: pointwiseOperation, + t1: t, + t2: t, + ) => { switch (pointwiseOp) { | `Add => pointwiseAdd(evaluationParams, t1, t2) | `Multiply => pointwiseMultiply(evaluationParams, t1, t2) @@ -166,7 +134,7 @@ module Truncate = { switch (leftCutoff, rightCutoff, t) { | (None, None, t) => `Solution(t) | (Some(lc), Some(rc), t) when lc > rc => - `Error("Left truncation bound must be smaller than right bound.") + `Error("Left truncation bound must be smaller than right truncation bound.") | (lc, rc, `SymbolicDist(`Uniform(u))) => `Solution( `SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))), @@ -179,7 +147,7 @@ module Truncate = { (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => { // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail // of a distribution we otherwise wouldn't get at all - switch (render(evaluationParams, t)) { + switch (Render.ensureIsRendered(evaluationParams, t)) { | Ok(`RenderedDist(rs)) => Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs))) | Error(e) => Error(e) diff --git a/src/distPlus/expressionTree/ExpressionTypes.re b/src/distPlus/expressionTree/ExpressionTypes.re index 06757d47..a76aee74 100644 --- a/src/distPlus/expressionTree/ExpressionTypes.re +++ b/src/distPlus/expressionTree/ExpressionTypes.re @@ -1,7 +1,13 @@ type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide]; type pointwiseOperation = [ | `Add | `Multiply]; type scaleOperation = [ | `Multiply | `Exponentiate | `Log]; -type distToFloatOperation = [ | `Pdf(float) | `Cdf(float) | `Inv(float) | `Mean | `Sample]; +type distToFloatOperation = [ + | `Pdf(float) + | `Cdf(float) + | `Inv(float) + | `Mean + | `Sample +]; module ExpressionTree = { type node = [ @@ -31,26 +37,42 @@ module ExpressionTree = { let evaluateNode = (evaluationParams: evaluationParams) => evaluationParams.evaluateNode(evaluationParams); - let render = (evaluationParams: evaluationParams, r) => - evaluateNode(evaluationParams, `Render(r)); - let evaluateAndRetry = (evaluationParams, fn, node) => node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams)); - let renderable = - fun - | `SymbolicDist(_) => true - | `RenderedDist(_) => true - | _ => false; + module Render = { + type t = node; + + let render = (evaluationParams: evaluationParams, r) => + `Render(r) |> evaluateNode(evaluationParams); + + let ensureIsRendered = (params, t) => + switch (t) { + | `RenderedDist(_) => Ok(t) + | _ => + switch (render(params, t)) { + | Ok(`RenderedDist(r)) => Ok(`RenderedDist(r)) + | Ok(_) => Error("Did not render as requested") + | Error(e) => Error(e) + } + }; + + let ensureIsRenderedAndGetShape = (params, t) => + switch (ensureIsRendered(params, t)) { + | Ok(`RenderedDist(r)) => Ok(r) + | Ok(_) => Error("Did not render as requested") + | Error(e) => Error(e) + }; + + let getShape = (item: node) => + switch (item) { + | `RenderedDist(r) => Some(r) + | _ => None + }; + }; - let mapRenderable = (renderedFn, symFn, item: node) => - switch (item) { - | `SymbolicDist(s) => Some(symFn(s)) - | `RenderedDist(r) => Some(renderedFn(r)) - | _ => None - }; }; type simplificationResult = [ diff --git a/src/distPlus/expressionTree/SamplingDistribution.re b/src/distPlus/expressionTree/SamplingDistribution.re new file mode 100644 index 00000000..60886b40 --- /dev/null +++ b/src/distPlus/expressionTree/SamplingDistribution.re @@ -0,0 +1,80 @@ +open ExpressionTypes.ExpressionTree; + +let isSamplingDistribution: node => bool = + fun + | `SymbolicDist(_) => true + | `RenderedDist(_) => true + | _ => false; + +let renderIfIsNotSamplingDistribution = (params, t) => + !isSamplingDistribution(t) + ? switch (Render.render(params, t)) { + | Ok(r) => Ok(r) + | Error(e) => Error(e) + } + : Ok(t); + +let map = (~renderedDistFn, ~symbolicDistFn, node: node) => + node + |> ( + fun + | `RenderedDist(r) => Some(renderedDistFn(r)) + | `SymbolicDist(s) => Some(symbolicDistFn(s)) + | _ => None + ); + +let sampleN = n => + map( + ~renderedDistFn=Shape.sampleNRendered(n), + ~symbolicDistFn=SymbolicDist.T.sampleN(n), + ); + +let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => { + switch (sampleN(n, t1), sampleN(n, t2)) { + | (Some(a), Some(b)) => + Some( + Belt.Array.zip(a, b) + |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)), + ) + | _ => None + }; +}; + +let combineShapesUsingSampling = + (evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => { + let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1); + let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2); + E.R.merge(i1, i2) + |> E.R.bind( + _, + ((a, b)) => { + let samples = + getCombinationSamples( + evaluationParams.samplingInputs.sampleCount, + algebraicOp, + a, + b, + ); + + // todo: This bottom part should probably be somewhere else. + let shape = + samples + |> E.O.fmap( + Samples.T.fromSamples( + ~samplingInputs={ + sampleCount: + Some(evaluationParams.samplingInputs.sampleCount), + outputXYPoints: + Some(evaluationParams.samplingInputs.outputXYPoints), + kernelWidth: evaluationParams.samplingInputs.kernelWidth, + }, + ), + ) + |> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) => + r.shape + ) + |> E.O.toResult("No response"); + shape |> E.R.fmap(r => `Normalize(`RenderedDist(r))); + }, + ); +};