Moved rendering code from TreeNode to SymbolicDist

This commit is contained in:
Ozzie Gooen 2020-07-02 14:30:01 +01:00
parent 99c0803953
commit 96df9ced85
2 changed files with 43 additions and 73 deletions

View File

@ -300,4 +300,18 @@ module T = {
|> E.O.dimap(r => `AnalyticalSolution(r), () => `NoSolution) |> E.O.dimap(r => `AnalyticalSolution(r), () => `NoSolution)
| _ => `NoSolution | _ => `NoSolution
}; };
let toShape = (sampleCount, d: symbolicDist): DistTypes.shape =>
switch (d) {
| `Float(v) =>
Discrete(
Distributions.Discrete.make({xs: [|v|], ys: [|1.0|]}, Some(1.0)),
)
| _ =>
let xs = interpolateXs(~xSelection=`ByWeight, d, sampleCount);
let ys = xs |> E.A.fmap(x => pdf(x, d));
Continuous(
Distributions.Continuous.make(`Linear, {xs, ys}, Some(1.0)),
);
};
}; };

View File

@ -102,16 +102,15 @@ module TreeNode = {
}; };
}; };
let evaluateToLeaf = let toLeaf =
( (
algebraicOp: SymbolicTypes.algebraicOperation,
operationToLeaf, operationToLeaf,
algebraicOp: SymbolicTypes.algebraicOperation,
t1: t, t1: t,
t2: t, t2: t,
) )
: result(treeNode, string) => : result(treeNode, string) =>
algebraicOp toTreeNode(algebraicOp, t1, t2)
|> toTreeNode(_, t1, t2)
|> tryAnalyticalSolution |> tryAnalyticalSolution
|> E.R.bind( |> E.R.bind(
_, _,
@ -124,7 +123,7 @@ module TreeNode = {
}; };
module VerticalScaling = { module VerticalScaling = {
let evaluateToLeaf = (scaleOp, operationToLeaf, t, scaleBy) => { let toLeaf = (operationToLeaf,scaleOp, t, scaleBy) => {
// scaleBy has to be a single float, otherwise we'll return an error. // scaleBy has to be a single float, otherwise we'll return an error.
let fn = SymbolicTypes.Scale.toFn(scaleOp); let fn = SymbolicTypes.Scale.toFn(scaleOp);
let knownIntegralSumFn = let knownIntegralSumFn =
@ -183,7 +182,7 @@ module TreeNode = {
); );
}; };
let evaluateToLeaf = (pointwiseOp, operationToLeaf, t1, t2) => { let toLeaf = (operationToLeaf,pointwiseOp, t1, t2) => {
switch (pointwiseOp) { switch (pointwiseOp) {
| `Add => pointwiseAdd(operationToLeaf, t1, t2) | `Add => pointwiseAdd(operationToLeaf, t1, t2)
| `Multiply => pointwiseMultiply(operationToLeaf, t1, t2) | `Multiply => pointwiseMultiply(operationToLeaf, t1, t2)
@ -235,11 +234,11 @@ module TreeNode = {
}; };
}; };
let evaluateToLeaf = let toLeaf =
( (
operationToLeaf,
leftCutoff: option(float), leftCutoff: option(float),
rightCutoff: option(float), rightCutoff: option(float),
operationToLeaf,
t: treeNode, t: treeNode,
) )
: result(treeNode, string) => { : result(treeNode, string) => {
@ -256,15 +255,13 @@ module TreeNode = {
}; };
module Normalize = { module Normalize = {
let rec evaluateToLeaf = let rec toLeaf = (operationToLeaf, t: treeNode): result(treeNode, string) => {
(operationToLeaf, t: treeNode): result(treeNode, string) => {
switch (t) { switch (t) {
| `Leaf(`SymbolicDist(_)) => Ok(t)
| `Leaf(`RenderedDist(s)) => | `Leaf(`RenderedDist(s)) =>
let normalized = Distributions.Shape.T.normalize(s); Ok(`Leaf(`RenderedDist(Distributions.Shape.T.normalize(s))))
Ok(`Leaf(`RenderedDist(normalized))); | `Leaf(`SymbolicDist(_)) => Ok(t)
| `Operation(op) => | `Operation(op) =>
E.R.bind(operationToLeaf(op), evaluateToLeaf(operationToLeaf)) operationToLeaf(op) |> E.R.bind(_, toLeaf(operationToLeaf))
}; };
}; };
}; };
@ -280,10 +277,10 @@ module TreeNode = {
Distributions.Shape.operate(distToFloatOp, rs) Distributions.Shape.operate(distToFloatOp, rs)
|> (v => Ok(`Leaf(`SymbolicDist(`Float(v))))); |> (v => Ok(`Leaf(`SymbolicDist(`Float(v)))));
}; };
let rec evaluateToLeaf = let rec toLeaf =
( (
distToFloatOp: distToFloatOperation,
operationToLeaf, operationToLeaf,
distToFloatOp: distToFloatOperation,
t: treeNode, t: treeNode,
) )
: result(treeNode, string) => { : result(treeNode, string) => {
@ -294,14 +291,14 @@ module TreeNode = {
| `Operation(op) => | `Operation(op) =>
E.R.bind( E.R.bind(
operationToLeaf(op), operationToLeaf(op),
evaluateToLeaf(distToFloatOp, operationToLeaf), toLeaf(operationToLeaf,distToFloatOp),
) )
}; };
}; };
}; };
module Render = { module Render = {
let rec evaluateToRenderedDist = let rec toLeaf =
( (
operationToLeaf: operation => result(t, string), operationToLeaf: operation => result(t, string),
sampleCount: int, sampleCount: int,
@ -309,49 +306,13 @@ module TreeNode = {
) )
: result(t, string) => { : result(t, string) => {
switch (t) { switch (t) {
| `Leaf(`RenderedDist(s)) => Ok(`Leaf(`RenderedDist(s))) // already a rendered shape, we're done here
| `Leaf(`SymbolicDist(d)) => | `Leaf(`SymbolicDist(d)) =>
// todo: move to dist Ok(`Leaf(`RenderedDist(SymbolicDist.T.toShape(sampleCount, d))))
switch (d) { | `Leaf(`RenderedDist(_)) as t => Ok(t) // already a rendered shape, we're done here
| `Float(v) =>
Ok(
`Leaf(
`RenderedDist(
Discrete(
Distributions.Discrete.make(
{xs: [|v|], ys: [|1.0|]},
Some(1.0),
),
),
),
),
)
| _ =>
let xs =
SymbolicDist.T.interpolateXs(
~xSelection=`ByWeight,
d,
sampleCount,
);
let ys = xs |> E.A.fmap(x => SymbolicDist.T.pdf(x, d));
Ok(
`Leaf(
`RenderedDist(
Continuous(
Distributions.Continuous.make(
`Linear,
{xs, ys},
Some(1.0),
),
),
),
),
);
}
| `Operation(op) => | `Operation(op) =>
E.R.bind( E.R.bind(
operationToLeaf(op), operationToLeaf(op),
evaluateToRenderedDist(operationToLeaf, sampleCount), toLeaf(operationToLeaf, sampleCount),
) )
}; };
}; };
@ -363,43 +324,38 @@ module TreeNode = {
// have a way to call this function on their children, if their children are themselves Operation nodes. // have a way to call this function on their children, if their children are themselves Operation nodes.
switch (op) { switch (op) {
| `AlgebraicCombination(algebraicOp, t1, t2) => | `AlgebraicCombination(algebraicOp, t1, t2) =>
AlgebraicCombination.evaluateToLeaf( AlgebraicCombination.toLeaf(
algebraicOp,
operationToLeaf(sampleCount), operationToLeaf(sampleCount),
algebraicOp,
t1, t1,
t2 // we want to give it the option to render or simply leave it as is t2 // we want to give it the option to render or simply leave it as is
) )
| `PointwiseCombination(pointwiseOp, t1, t2) => | `PointwiseCombination(pointwiseOp, t1, t2) =>
PointwiseCombination.evaluateToLeaf( PointwiseCombination.toLeaf(
pointwiseOp,
operationToLeaf(sampleCount), operationToLeaf(sampleCount),
pointwiseOp,
t1, t1,
t2, t2,
) )
| `VerticalScaling(scaleOp, t, scaleBy) => | `VerticalScaling(scaleOp, t, scaleBy) =>
VerticalScaling.evaluateToLeaf( VerticalScaling.toLeaf(
scaleOp,
operationToLeaf(sampleCount), operationToLeaf(sampleCount),
scaleOp,
t, t,
scaleBy, scaleBy,
) )
| `Truncate(leftCutoff, rightCutoff, t) => | `Truncate(leftCutoff, rightCutoff, t) =>
Truncate.evaluateToLeaf( Truncate.toLeaf(
operationToLeaf(sampleCount),
leftCutoff, leftCutoff,
rightCutoff, rightCutoff,
operationToLeaf(sampleCount),
t, t,
) )
| `FloatFromDist(distToFloatOp, t) => | `FloatFromDist(distToFloatOp, t) =>
FloatFromDist.evaluateToLeaf( FloatFromDist.toLeaf(operationToLeaf(sampleCount),distToFloatOp, t)
distToFloatOp, | `Normalize(t) => Normalize.toLeaf(operationToLeaf(sampleCount), t)
operationToLeaf(sampleCount),
t,
)
| `Normalize(t) =>
Normalize.evaluateToLeaf(operationToLeaf(sampleCount), t)
| `Render(t) => | `Render(t) =>
Render.evaluateToRenderedDist( Render.toLeaf(
operationToLeaf(sampleCount), operationToLeaf(sampleCount),
sampleCount, sampleCount,
t, t,