Moved a bit of code to the SamplingDistribution file
This commit is contained in:
parent
f09caaae5e
commit
6d71722d34
|
@ -20,68 +20,18 @@ module AlgebraicCombination = {
|
||||||
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
|
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
|
||||||
};
|
};
|
||||||
|
|
||||||
let combinationBySampling = (n, algebraicOp, t1: node, t2: node) => {
|
|
||||||
let sampleN =
|
|
||||||
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
|
|
||||||
switch (sampleN(t1), sampleN(t2)) {
|
|
||||||
| (Some(a), Some(b)) =>
|
|
||||||
Some(
|
|
||||||
Belt.Array.zip(a, b)
|
|
||||||
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
|
|
||||||
)
|
|
||||||
| _ => None
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
let combinationByRendering =
|
let combinationByRendering =
|
||||||
(evaluationParams, algebraicOp, t1: node, t2: node)
|
(evaluationParams, algebraicOp, t1: node, t2: node)
|
||||||
: result(node, string) => {
|
: result(node, string) => {
|
||||||
E.R.merge(
|
E.R.merge(
|
||||||
renderAndGetShape(evaluationParams, t1),
|
Render.ensureIsRenderedAndGetShape(evaluationParams, t1),
|
||||||
renderAndGetShape(evaluationParams, t2),
|
Render.ensureIsRenderedAndGetShape(evaluationParams, t2),
|
||||||
)
|
)
|
||||||
|> E.R.fmap(((a, b)) =>
|
|> E.R.fmap(((a, b)) =>
|
||||||
`RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b))
|
`RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b))
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
let combineShapesUsingSampling =
|
|
||||||
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
|
|
||||||
let i1 = renderIfNotRenderable(evaluationParams, t1);
|
|
||||||
let i2 = renderIfNotRenderable(evaluationParams, t2);
|
|
||||||
E.R.merge(i1, i2)
|
|
||||||
|> E.R.bind(
|
|
||||||
_,
|
|
||||||
((a, b)) => {
|
|
||||||
let samples =
|
|
||||||
combinationBySampling(
|
|
||||||
evaluationParams.samplingInputs.sampleCount,
|
|
||||||
algebraicOp,
|
|
||||||
a,
|
|
||||||
b,
|
|
||||||
);
|
|
||||||
let shape =
|
|
||||||
samples
|
|
||||||
|> E.O.fmap(
|
|
||||||
Samples.T.fromSamples(
|
|
||||||
~samplingInputs={
|
|
||||||
sampleCount:
|
|
||||||
Some(evaluationParams.samplingInputs.sampleCount),
|
|
||||||
outputXYPoints:
|
|
||||||
Some(evaluationParams.samplingInputs.outputXYPoints),
|
|
||||||
kernelWidth: evaluationParams.samplingInputs.kernelWidth,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
|
|
||||||
r.shape
|
|
||||||
)
|
|
||||||
|> E.O.toResult("No response");
|
|
||||||
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
|
|
||||||
},
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
let operationToLeaf =
|
let operationToLeaf =
|
||||||
(
|
(
|
||||||
evaluationParams: evaluationParams,
|
evaluationParams: evaluationParams,
|
||||||
|
@ -107,7 +57,7 @@ module VerticalScaling = {
|
||||||
let fn = Operation.Scale.toFn(scaleOp);
|
let fn = Operation.Scale.toFn(scaleOp);
|
||||||
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
|
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
|
||||||
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
|
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
|
||||||
let renderedShape = render(evaluationParams, t);
|
let renderedShape = Render.render(evaluationParams, t);
|
||||||
|
|
||||||
switch (renderedShape, scaleBy) {
|
switch (renderedShape, scaleBy) {
|
||||||
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
|
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
|
||||||
|
@ -129,7 +79,7 @@ module VerticalScaling = {
|
||||||
|
|
||||||
module PointwiseCombination = {
|
module PointwiseCombination = {
|
||||||
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
||||||
switch (render(evaluationParams, t1), render(evaluationParams, t2)) {
|
switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) {
|
||||||
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
|
||||||
Ok(
|
Ok(
|
||||||
`RenderedDist(
|
`RenderedDist(
|
||||||
|
@ -184,7 +134,7 @@ module Truncate = {
|
||||||
switch (leftCutoff, rightCutoff, t) {
|
switch (leftCutoff, rightCutoff, t) {
|
||||||
| (None, None, t) => `Solution(t)
|
| (None, None, t) => `Solution(t)
|
||||||
| (Some(lc), Some(rc), t) when lc > rc =>
|
| (Some(lc), Some(rc), t) when lc > rc =>
|
||||||
`Error("Left truncation bound must be smaller than right bound.")
|
`Error("Left truncation bound must be smaller than right truncation bound.")
|
||||||
| (lc, rc, `SymbolicDist(`Uniform(u))) =>
|
| (lc, rc, `SymbolicDist(`Uniform(u))) =>
|
||||||
`Solution(
|
`Solution(
|
||||||
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
|
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
|
||||||
|
@ -197,7 +147,7 @@ module Truncate = {
|
||||||
(evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => {
|
(evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => {
|
||||||
// TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
|
// TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
|
||||||
// of a distribution we otherwise wouldn't get at all
|
// of a distribution we otherwise wouldn't get at all
|
||||||
switch (render(evaluationParams, t)) {
|
switch (Render.ensureIsRendered(evaluationParams, t)) {
|
||||||
| Ok(`RenderedDist(rs)) =>
|
| Ok(`RenderedDist(rs)) =>
|
||||||
Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
|
Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
|
||||||
| Error(e) => Error(e)
|
| Error(e) => Error(e)
|
||||||
|
|
|
@ -37,61 +37,42 @@ module ExpressionTree = {
|
||||||
let evaluateNode = (evaluationParams: evaluationParams) =>
|
let evaluateNode = (evaluationParams: evaluationParams) =>
|
||||||
evaluationParams.evaluateNode(evaluationParams);
|
evaluationParams.evaluateNode(evaluationParams);
|
||||||
|
|
||||||
let render = (evaluationParams: evaluationParams, r) =>
|
|
||||||
evaluateNode(evaluationParams, `Render(r));
|
|
||||||
|
|
||||||
let renderable =
|
|
||||||
fun
|
|
||||||
| `SymbolicDist(_) => true
|
|
||||||
| `RenderedDist(_) => true
|
|
||||||
| _ => false;
|
|
||||||
|
|
||||||
let renderIfNotRenderable = (params, t) =>
|
|
||||||
!renderable(t)
|
|
||||||
? switch (render(params, t)) {
|
|
||||||
| Ok(r) => Ok(r)
|
|
||||||
| Error(e) => Error(e)
|
|
||||||
}
|
|
||||||
: Ok(t);
|
|
||||||
|
|
||||||
let renderIfNotRendered = (params, t) =>
|
|
||||||
switch (t) {
|
|
||||||
| `RenderedDist(_) => Ok(t)
|
|
||||||
| _ =>
|
|
||||||
switch (render(params, t)) {
|
|
||||||
| Ok(r) => Ok(r)
|
|
||||||
| Error(e) => Error(e)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
let evaluateAndRetry = (evaluationParams, fn, node) =>
|
||||||
node
|
node
|
||||||
|> evaluationParams.evaluateNode(evaluationParams)
|
|> evaluationParams.evaluateNode(evaluationParams)
|
||||||
|> E.R.bind(_, fn(evaluationParams));
|
|> E.R.bind(_, fn(evaluationParams));
|
||||||
|
|
||||||
let renderedShape = (item: node) =>
|
module Render = {
|
||||||
switch (item) {
|
type t = node;
|
||||||
| `RenderedDist(r) => Some(r)
|
|
||||||
| _ => None
|
|
||||||
};
|
|
||||||
|
|
||||||
let renderAndGetShape = (params, t) =>
|
let render = (evaluationParams: evaluationParams, r) =>
|
||||||
switch (renderIfNotRendered(params, t)) {
|
`Render(r) |> evaluateNode(evaluationParams);
|
||||||
| Ok(`RenderedDist(r)) => Ok(r)
|
|
||||||
| Error(r) =>
|
let ensureIsRendered = (params, t) =>
|
||||||
Js.log(r);
|
switch (t) {
|
||||||
Error(r);
|
| `RenderedDist(_) => Ok(t)
|
||||||
| Ok(l) =>
|
| _ =>
|
||||||
Js.log(l);
|
switch (render(params, t)) {
|
||||||
Error("Did not render as requested");
|
| Ok(`RenderedDist(r)) => Ok(`RenderedDist(r))
|
||||||
};
|
| Ok(_) => Error("Did not render as requested")
|
||||||
|
| Error(e) => Error(e)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let ensureIsRenderedAndGetShape = (params, t) =>
|
||||||
|
switch (ensureIsRendered(params, t)) {
|
||||||
|
| Ok(`RenderedDist(r)) => Ok(r)
|
||||||
|
| Ok(_) => Error("Did not render as requested")
|
||||||
|
| Error(e) => Error(e)
|
||||||
|
};
|
||||||
|
|
||||||
|
let getShape = (item: node) =>
|
||||||
|
switch (item) {
|
||||||
|
| `RenderedDist(r) => Some(r)
|
||||||
|
| _ => None
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
let mapRenderable = (renderedFn, symFn, item: node) =>
|
|
||||||
switch (item) {
|
|
||||||
| `SymbolicDist(s) => Some(symFn(s))
|
|
||||||
| `RenderedDist(r) => Some(renderedFn(r))
|
|
||||||
| _ => None
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
type simplificationResult = [
|
type simplificationResult = [
|
||||||
|
|
80
src/distPlus/expressionTree/SamplingDistribution.re
Normal file
80
src/distPlus/expressionTree/SamplingDistribution.re
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
open ExpressionTypes.ExpressionTree;
|
||||||
|
|
||||||
|
let isSamplingDistribution: node => bool =
|
||||||
|
fun
|
||||||
|
| `SymbolicDist(_) => true
|
||||||
|
| `RenderedDist(_) => true
|
||||||
|
| _ => false;
|
||||||
|
|
||||||
|
let renderIfIsNotSamplingDistribution = (params, t) =>
|
||||||
|
!isSamplingDistribution(t)
|
||||||
|
? switch (Render.render(params, t)) {
|
||||||
|
| Ok(r) => Ok(r)
|
||||||
|
| Error(e) => Error(e)
|
||||||
|
}
|
||||||
|
: Ok(t);
|
||||||
|
|
||||||
|
let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
|
||||||
|
node
|
||||||
|
|> (
|
||||||
|
fun
|
||||||
|
| `RenderedDist(r) => Some(renderedDistFn(r))
|
||||||
|
| `SymbolicDist(s) => Some(symbolicDistFn(s))
|
||||||
|
| _ => None
|
||||||
|
);
|
||||||
|
|
||||||
|
let sampleN = n =>
|
||||||
|
map(
|
||||||
|
~renderedDistFn=Shape.sampleNRendered(n),
|
||||||
|
~symbolicDistFn=SymbolicDist.T.sampleN(n),
|
||||||
|
);
|
||||||
|
|
||||||
|
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => {
|
||||||
|
switch (sampleN(n, t1), sampleN(n, t2)) {
|
||||||
|
| (Some(a), Some(b)) =>
|
||||||
|
Some(
|
||||||
|
Belt.Array.zip(a, b)
|
||||||
|
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
|
||||||
|
)
|
||||||
|
| _ => None
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
let combineShapesUsingSampling =
|
||||||
|
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
|
||||||
|
let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1);
|
||||||
|
let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2);
|
||||||
|
E.R.merge(i1, i2)
|
||||||
|
|> E.R.bind(
|
||||||
|
_,
|
||||||
|
((a, b)) => {
|
||||||
|
let samples =
|
||||||
|
getCombinationSamples(
|
||||||
|
evaluationParams.samplingInputs.sampleCount,
|
||||||
|
algebraicOp,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
);
|
||||||
|
|
||||||
|
// todo: This bottom part should probably be somewhere else.
|
||||||
|
let shape =
|
||||||
|
samples
|
||||||
|
|> E.O.fmap(
|
||||||
|
Samples.T.fromSamples(
|
||||||
|
~samplingInputs={
|
||||||
|
sampleCount:
|
||||||
|
Some(evaluationParams.samplingInputs.sampleCount),
|
||||||
|
outputXYPoints:
|
||||||
|
Some(evaluationParams.samplingInputs.outputXYPoints),
|
||||||
|
kernelWidth: evaluationParams.samplingInputs.kernelWidth,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
|
||||||
|
r.shape
|
||||||
|
)
|
||||||
|
|> E.O.toResult("No response");
|
||||||
|
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
|
||||||
|
},
|
||||||
|
);
|
||||||
|
};
|
Loading…
Reference in New Issue
Block a user