diff --git a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re index 1047fb8d..c1220506 100644 --- a/src/distPlus/expressionTree/ExpressionTreeEvaluator.re +++ b/src/distPlus/expressionTree/ExpressionTreeEvaluator.re @@ -20,68 +20,18 @@ module AlgebraicCombination = { | _ => Ok(`AlgebraicCombination((operation, t1, t2))) }; - let combinationBySampling = (n, algebraicOp, t1: node, t2: node) => { - let sampleN = - mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n)); - switch (sampleN(t1), sampleN(t2)) { - | (Some(a), Some(b)) => - Some( - Belt.Array.zip(a, b) - |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)), - ) - | _ => None - }; - }; - let combinationByRendering = (evaluationParams, algebraicOp, t1: node, t2: node) : result(node, string) => { E.R.merge( - renderAndGetShape(evaluationParams, t1), - renderAndGetShape(evaluationParams, t2), + Render.ensureIsRenderedAndGetShape(evaluationParams, t1), + Render.ensureIsRenderedAndGetShape(evaluationParams, t2), ) |> E.R.fmap(((a, b)) => `RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b)) ); }; - let combineShapesUsingSampling = - (evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => { - let i1 = renderIfNotRenderable(evaluationParams, t1); - let i2 = renderIfNotRenderable(evaluationParams, t2); - E.R.merge(i1, i2) - |> E.R.bind( - _, - ((a, b)) => { - let samples = - combinationBySampling( - evaluationParams.samplingInputs.sampleCount, - algebraicOp, - a, - b, - ); - let shape = - samples - |> E.O.fmap( - Samples.T.fromSamples( - ~samplingInputs={ - sampleCount: - Some(evaluationParams.samplingInputs.sampleCount), - outputXYPoints: - Some(evaluationParams.samplingInputs.outputXYPoints), - kernelWidth: evaluationParams.samplingInputs.kernelWidth, - }, - ), - ) - |> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) => - r.shape - ) - |> E.O.toResult("No response"); - shape |> E.R.fmap(r => `Normalize(`RenderedDist(r))); - }, - ); - }; - let operationToLeaf = ( evaluationParams: evaluationParams, @@ -107,7 +57,7 @@ module VerticalScaling = { let fn = Operation.Scale.toFn(scaleOp); let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp); let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp); - let renderedShape = render(evaluationParams, t); + let renderedShape = Render.render(evaluationParams, t); switch (renderedShape, scaleBy) { | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) => @@ -129,7 +79,7 @@ module VerticalScaling = { module PointwiseCombination = { let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => { - switch (render(evaluationParams, t1), render(evaluationParams, t2)) { + switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) { | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => Ok( `RenderedDist( @@ -184,7 +134,7 @@ module Truncate = { switch (leftCutoff, rightCutoff, t) { | (None, None, t) => `Solution(t) | (Some(lc), Some(rc), t) when lc > rc => - `Error("Left truncation bound must be smaller than right bound.") + `Error("Left truncation bound must be smaller than right truncation bound.") | (lc, rc, `SymbolicDist(`Uniform(u))) => `Solution( `SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))), @@ -197,7 +147,7 @@ module Truncate = { (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => { // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail // of a distribution we otherwise wouldn't get at all - switch (render(evaluationParams, t)) { + switch (Render.ensureIsRendered(evaluationParams, t)) { | Ok(`RenderedDist(rs)) => Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs))) | Error(e) => Error(e) diff --git a/src/distPlus/expressionTree/ExpressionTypes.re b/src/distPlus/expressionTree/ExpressionTypes.re index 2cee4630..a76aee74 100644 --- a/src/distPlus/expressionTree/ExpressionTypes.re +++ b/src/distPlus/expressionTree/ExpressionTypes.re @@ -37,61 +37,42 @@ module ExpressionTree = { let evaluateNode = (evaluationParams: evaluationParams) => evaluationParams.evaluateNode(evaluationParams); - let render = (evaluationParams: evaluationParams, r) => - evaluateNode(evaluationParams, `Render(r)); - - let renderable = - fun - | `SymbolicDist(_) => true - | `RenderedDist(_) => true - | _ => false; - - let renderIfNotRenderable = (params, t) => - !renderable(t) - ? switch (render(params, t)) { - | Ok(r) => Ok(r) - | Error(e) => Error(e) - } - : Ok(t); - - let renderIfNotRendered = (params, t) => - switch (t) { - | `RenderedDist(_) => Ok(t) - | _ => - switch (render(params, t)) { - | Ok(r) => Ok(r) - | Error(e) => Error(e) - } - }; - let evaluateAndRetry = (evaluationParams, fn, node) => node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams)); - let renderedShape = (item: node) => - switch (item) { - | `RenderedDist(r) => Some(r) - | _ => None - }; + module Render = { + type t = node; - let renderAndGetShape = (params, t) => - switch (renderIfNotRendered(params, t)) { - | Ok(`RenderedDist(r)) => Ok(r) - | Error(r) => - Js.log(r); - Error(r); - | Ok(l) => - Js.log(l); - Error("Did not render as requested"); - }; + let render = (evaluationParams: evaluationParams, r) => + `Render(r) |> evaluateNode(evaluationParams); + + let ensureIsRendered = (params, t) => + switch (t) { + | `RenderedDist(_) => Ok(t) + | _ => + switch (render(params, t)) { + | Ok(`RenderedDist(r)) => Ok(`RenderedDist(r)) + | Ok(_) => Error("Did not render as requested") + | Error(e) => Error(e) + } + }; + + let ensureIsRenderedAndGetShape = (params, t) => + switch (ensureIsRendered(params, t)) { + | Ok(`RenderedDist(r)) => Ok(r) + | Ok(_) => Error("Did not render as requested") + | Error(e) => Error(e) + }; + + let getShape = (item: node) => + switch (item) { + | `RenderedDist(r) => Some(r) + | _ => None + }; + }; - let mapRenderable = (renderedFn, symFn, item: node) => - switch (item) { - | `SymbolicDist(s) => Some(symFn(s)) - | `RenderedDist(r) => Some(renderedFn(r)) - | _ => None - }; }; type simplificationResult = [ diff --git a/src/distPlus/expressionTree/SamplingDistribution.re b/src/distPlus/expressionTree/SamplingDistribution.re new file mode 100644 index 00000000..60886b40 --- /dev/null +++ b/src/distPlus/expressionTree/SamplingDistribution.re @@ -0,0 +1,80 @@ +open ExpressionTypes.ExpressionTree; + +let isSamplingDistribution: node => bool = + fun + | `SymbolicDist(_) => true + | `RenderedDist(_) => true + | _ => false; + +let renderIfIsNotSamplingDistribution = (params, t) => + !isSamplingDistribution(t) + ? switch (Render.render(params, t)) { + | Ok(r) => Ok(r) + | Error(e) => Error(e) + } + : Ok(t); + +let map = (~renderedDistFn, ~symbolicDistFn, node: node) => + node + |> ( + fun + | `RenderedDist(r) => Some(renderedDistFn(r)) + | `SymbolicDist(s) => Some(symbolicDistFn(s)) + | _ => None + ); + +let sampleN = n => + map( + ~renderedDistFn=Shape.sampleNRendered(n), + ~symbolicDistFn=SymbolicDist.T.sampleN(n), + ); + +let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => { + switch (sampleN(n, t1), sampleN(n, t2)) { + | (Some(a), Some(b)) => + Some( + Belt.Array.zip(a, b) + |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)), + ) + | _ => None + }; +}; + +let combineShapesUsingSampling = + (evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => { + let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1); + let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2); + E.R.merge(i1, i2) + |> E.R.bind( + _, + ((a, b)) => { + let samples = + getCombinationSamples( + evaluationParams.samplingInputs.sampleCount, + algebraicOp, + a, + b, + ); + + // todo: This bottom part should probably be somewhere else. + let shape = + samples + |> E.O.fmap( + Samples.T.fromSamples( + ~samplingInputs={ + sampleCount: + Some(evaluationParams.samplingInputs.sampleCount), + outputXYPoints: + Some(evaluationParams.samplingInputs.outputXYPoints), + kernelWidth: evaluationParams.samplingInputs.kernelWidth, + }, + ), + ) + |> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) => + r.shape + ) + |> E.O.toResult("No response"); + shape |> E.R.fmap(r => `Normalize(`RenderedDist(r))); + }, + ); +};