Moved a bit of code to the SamplingDistribution file

This commit is contained in:
Ozzie Gooen 2020-07-17 23:00:17 +01:00
parent f09caaae5e
commit 6d71722d34
3 changed files with 115 additions and 104 deletions

View File

@ -20,68 +20,18 @@ module AlgebraicCombination = {
| _ => Ok(`AlgebraicCombination((operation, t1, t2))) | _ => Ok(`AlgebraicCombination((operation, t1, t2)))
}; };
let combinationBySampling = (n, algebraicOp, t1: node, t2: node) => {
let sampleN =
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
switch (sampleN(t1), sampleN(t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b)
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| _ => None
};
};
let combinationByRendering = let combinationByRendering =
(evaluationParams, algebraicOp, t1: node, t2: node) (evaluationParams, algebraicOp, t1: node, t2: node)
: result(node, string) => { : result(node, string) => {
E.R.merge( E.R.merge(
renderAndGetShape(evaluationParams, t1), Render.ensureIsRenderedAndGetShape(evaluationParams, t1),
renderAndGetShape(evaluationParams, t2), Render.ensureIsRenderedAndGetShape(evaluationParams, t2),
) )
|> E.R.fmap(((a, b)) => |> E.R.fmap(((a, b)) =>
`RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b)) `RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b))
); );
}; };
let combineShapesUsingSampling =
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let i1 = renderIfNotRenderable(evaluationParams, t1);
let i2 = renderIfNotRenderable(evaluationParams, t2);
E.R.merge(i1, i2)
|> E.R.bind(
_,
((a, b)) => {
let samples =
combinationBySampling(
evaluationParams.samplingInputs.sampleCount,
algebraicOp,
a,
b,
);
let shape =
samples
|> E.O.fmap(
Samples.T.fromSamples(
~samplingInputs={
sampleCount:
Some(evaluationParams.samplingInputs.sampleCount),
outputXYPoints:
Some(evaluationParams.samplingInputs.outputXYPoints),
kernelWidth: evaluationParams.samplingInputs.kernelWidth,
},
),
)
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
r.shape
)
|> E.O.toResult("No response");
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
},
);
};
let operationToLeaf = let operationToLeaf =
( (
evaluationParams: evaluationParams, evaluationParams: evaluationParams,
@ -107,7 +57,7 @@ module VerticalScaling = {
let fn = Operation.Scale.toFn(scaleOp); let fn = Operation.Scale.toFn(scaleOp);
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp); let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp); let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
let renderedShape = render(evaluationParams, t); let renderedShape = Render.render(evaluationParams, t);
switch (renderedShape, scaleBy) { switch (renderedShape, scaleBy) {
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) => | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
@ -129,7 +79,7 @@ module VerticalScaling = {
module PointwiseCombination = { module PointwiseCombination = {
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => { let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
switch (render(evaluationParams, t1), render(evaluationParams, t2)) { switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) {
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
Ok( Ok(
`RenderedDist( `RenderedDist(
@ -184,7 +134,7 @@ module Truncate = {
switch (leftCutoff, rightCutoff, t) { switch (leftCutoff, rightCutoff, t) {
| (None, None, t) => `Solution(t) | (None, None, t) => `Solution(t)
| (Some(lc), Some(rc), t) when lc > rc => | (Some(lc), Some(rc), t) when lc > rc =>
`Error("Left truncation bound must be smaller than right bound.") `Error("Left truncation bound must be smaller than right truncation bound.")
| (lc, rc, `SymbolicDist(`Uniform(u))) => | (lc, rc, `SymbolicDist(`Uniform(u))) =>
`Solution( `Solution(
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))), `SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
@ -197,7 +147,7 @@ module Truncate = {
(evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => { (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => {
// TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
// of a distribution we otherwise wouldn't get at all // of a distribution we otherwise wouldn't get at all
switch (render(evaluationParams, t)) { switch (Render.ensureIsRendered(evaluationParams, t)) {
| Ok(`RenderedDist(rs)) => | Ok(`RenderedDist(rs)) =>
Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs))) Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
| Error(e) => Error(e) | Error(e) => Error(e)

View File

@ -37,61 +37,42 @@ module ExpressionTree = {
let evaluateNode = (evaluationParams: evaluationParams) => let evaluateNode = (evaluationParams: evaluationParams) =>
evaluationParams.evaluateNode(evaluationParams); evaluationParams.evaluateNode(evaluationParams);
let render = (evaluationParams: evaluationParams, r) =>
evaluateNode(evaluationParams, `Render(r));
let renderable =
fun
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
let renderIfNotRenderable = (params, t) =>
!renderable(t)
? switch (render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let renderIfNotRendered = (params, t) =>
switch (t) {
| `RenderedDist(_) => Ok(t)
| _ =>
switch (render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
};
let evaluateAndRetry = (evaluationParams, fn, node) => let evaluateAndRetry = (evaluationParams, fn, node) =>
node node
|> evaluationParams.evaluateNode(evaluationParams) |> evaluationParams.evaluateNode(evaluationParams)
|> E.R.bind(_, fn(evaluationParams)); |> E.R.bind(_, fn(evaluationParams));
let renderedShape = (item: node) => module Render = {
switch (item) { type t = node;
| `RenderedDist(r) => Some(r)
| _ => None
};
let renderAndGetShape = (params, t) => let render = (evaluationParams: evaluationParams, r) =>
switch (renderIfNotRendered(params, t)) { `Render(r) |> evaluateNode(evaluationParams);
| Ok(`RenderedDist(r)) => Ok(r)
| Error(r) => let ensureIsRendered = (params, t) =>
Js.log(r); switch (t) {
Error(r); | `RenderedDist(_) => Ok(t)
| Ok(l) => | _ =>
Js.log(l); switch (render(params, t)) {
Error("Did not render as requested"); | Ok(`RenderedDist(r)) => Ok(`RenderedDist(r))
}; | Ok(_) => Error("Did not render as requested")
| Error(e) => Error(e)
}
};
let ensureIsRenderedAndGetShape = (params, t) =>
switch (ensureIsRendered(params, t)) {
| Ok(`RenderedDist(r)) => Ok(r)
| Ok(_) => Error("Did not render as requested")
| Error(e) => Error(e)
};
let getShape = (item: node) =>
switch (item) {
| `RenderedDist(r) => Some(r)
| _ => None
};
};
let mapRenderable = (renderedFn, symFn, item: node) =>
switch (item) {
| `SymbolicDist(s) => Some(symFn(s))
| `RenderedDist(r) => Some(renderedFn(r))
| _ => None
};
}; };
type simplificationResult = [ type simplificationResult = [

View File

@ -0,0 +1,80 @@
open ExpressionTypes.ExpressionTree;
let isSamplingDistribution: node => bool =
fun
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
let renderIfIsNotSamplingDistribution = (params, t) =>
!isSamplingDistribution(t)
? switch (Render.render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
node
|> (
fun
| `RenderedDist(r) => Some(renderedDistFn(r))
| `SymbolicDist(s) => Some(symbolicDistFn(s))
| _ => None
);
let sampleN = n =>
map(
~renderedDistFn=Shape.sampleNRendered(n),
~symbolicDistFn=SymbolicDist.T.sampleN(n),
);
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => {
switch (sampleN(n, t1), sampleN(n, t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b)
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| _ => None
};
};
let combineShapesUsingSampling =
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1);
let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2);
E.R.merge(i1, i2)
|> E.R.bind(
_,
((a, b)) => {
let samples =
getCombinationSamples(
evaluationParams.samplingInputs.sampleCount,
algebraicOp,
a,
b,
);
// todo: This bottom part should probably be somewhere else.
let shape =
samples
|> E.O.fmap(
Samples.T.fromSamples(
~samplingInputs={
sampleCount:
Some(evaluationParams.samplingInputs.sampleCount),
outputXYPoints:
Some(evaluationParams.samplingInputs.outputXYPoints),
kernelWidth: evaluationParams.samplingInputs.kernelWidth,
},
),
)
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
r.shape
)
|> E.O.toResult("No response");
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
},
);
};