Fix conflicts

This commit is contained in:
Sebastian Kosch 2020-07-18 21:45:47 -07:00
commit d9cb164e5d
5 changed files with 181 additions and 109 deletions

View File

@ -110,7 +110,11 @@ let toDiscretePointMassesFromTriangulars =
}; };
let combineShapesContinuousContinuous = let combineShapesContinuousContinuous =
(op: ExpressionTypes.algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape) (
op: ExpressionTypes.algebraicOperation,
s1: DistTypes.xyShape,
s2: DistTypes.xyShape,
)
: DistTypes.xyShape => { : DistTypes.xyShape => {
let t1n = s1 |> XYShape.T.length; let t1n = s1 |> XYShape.T.length;
let t2n = s2 |> XYShape.T.length; let t2n = s2 |> XYShape.T.length;
@ -118,7 +122,8 @@ let combineShapesContinuousContinuous =
// if we add the two distributions, we should probably use normal filters. // if we add the two distributions, we should probably use normal filters.
// if we multiply the two distributions, we should probably use lognormal filters. // if we multiply the two distributions, we should probably use lognormal filters.
let t1m = toDiscretePointMassesFromTriangulars(s1); let t1m = toDiscretePointMassesFromTriangulars(s1);
let t2m = switch (op) { let t2m =
switch (op) {
| `Divide => toDiscretePointMassesFromTriangulars(~inverse=true, s2) | `Divide => toDiscretePointMassesFromTriangulars(~inverse=true, s2)
| _ => toDiscretePointMassesFromTriangulars(~inverse=false, s2) | _ => toDiscretePointMassesFromTriangulars(~inverse=false, s2)
}; };
@ -185,27 +190,30 @@ let combineShapesContinuousContinuous =
// we now want to create a set of target points. For now, let's just evenly distribute 200 points between // we now want to create a set of target points. For now, let's just evenly distribute 200 points between
// between the outputMinX and outputMaxX // between the outputMinX and outputMaxX
let nOut = 300; let nOut = 300;
let outputXs: array(float) = E.A.Floats.range(outputMinX^, outputMaxX^, nOut); let outputXs: array(float) =
E.A.Floats.range(outputMinX^, outputMaxX^, nOut);
let outputYs: array(float) = Belt.Array.make(nOut, 0.0); let outputYs: array(float) = Belt.Array.make(nOut, 0.0);
// now, for each of the outputYs, accumulate from a Gaussian kernel over each input point. // now, for each of the outputYs, accumulate from a Gaussian kernel over each input point.
for (j in 0 to E.A.length(masses) - 1) { // go through all of the result points for (j in 0 to E.A.length(masses) - 1) {
// go through all of the result points
if (variances[j] > 0. && masses[j] > 0.) { if (variances[j] > 0. && masses[j] > 0.) {
for (i in 0 to E.A.length(outputXs) - 1) { // go through all of the target points for (i in 0 to E.A.length(outputXs) - 1) {
// go through all of the target points
let dx = outputXs[i] -. means[j]; let dx = outputXs[i] -. means[j];
let contribution = masses[j] *. exp(-. (dx ** 2.) /. (2. *. variances[j])) /. (sqrt(2. *. 3.14159276 *. variances[j])); let contribution =
masses[j]
*. exp(-. (dx ** 2.) /. (2. *. variances[j]))
/. sqrt(2. *. 3.14159276 *. variances[j]);
Belt.Array.set(outputYs, i, outputYs[i] +. contribution) |> ignore; Belt.Array.set(outputYs, i, outputYs[i] +. contribution) |> ignore;
();
}; };
(); };
} |> ignore;
();
}; };
{xs: outputXs, ys: outputYs}; {xs: outputXs, ys: outputYs};
}; };
let toDiscretePointMassesFromDiscrete = (s: DistTypes.xyShape): pointMassesWithMoments => { let toDiscretePointMassesFromDiscrete =
let n = s |> XYShape.T.length; (s: DistTypes.xyShape): pointMassesWithMoments => {
let {xs, ys}: XYShape.T.t = s; let {xs, ys}: XYShape.T.t = s;
let n = E.A.length(xs); let n = E.A.length(xs);
@ -231,7 +239,7 @@ let combineShapesContinuousDiscrete =
switch (op) { switch (op) {
| `Add | `Add
| `Subtract => { | `Subtract =>
for (j in 0 to t2n - 1) { for (j in 0 to t2n - 1) {
// creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes. // creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes.
let dxyShape: array((float, float)) = let dxyShape: array((float, float)) =
@ -240,16 +248,16 @@ let combineShapesContinuousDiscrete =
Belt.Array.set( Belt.Array.set(
dxyShape, dxyShape,
i, i,
(fn(continuousShape.xs[i], discreteShape.xs[j]), continuousShape.ys[i] *. discreteShape.ys[j]), (fn(continuousShape.xs[i], discreteShape.xs[j]),
continuousShape.ys[i] *. discreteShape.ys[j]),
) |> ignore; ) |> ignore;
(); ();
}; };
Belt.Array.set(outXYShapes, j, dxyShape) |> ignore; Belt.Array.set(outXYShapes, j, dxyShape) |> ignore;
(); ();
} }
}
| `Multiply | `Multiply
| `Divide => { | `Divide =>
for (j in 0 to t2n - 1) { for (j in 0 to t2n - 1) {
// creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes. // creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes.
let dxyShape: array((float, float)) = let dxyShape: array((float, float)) =
@ -265,7 +273,6 @@ let combineShapesContinuousDiscrete =
Belt.Array.set(outXYShapes, j, dxyShape) |> ignore; Belt.Array.set(outXYShapes, j, dxyShape) |> ignore;
(); ();
} }
}
}; };
outXYShapes outXYShapes

View File

@ -27,24 +27,19 @@ let combineAlgebraically =
(op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => { (op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => {
switch (t1, t2) { switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) => | (Continuous(m1), Continuous(m2)) =>
DistTypes.Continuous( Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toShape;
Continuous.combineAlgebraically(op, m1, m2),
)
| (Continuous(m1), Discrete(m2)) | (Continuous(m1), Discrete(m2))
| (Discrete(m2), Continuous(m1)) => | (Discrete(m2), Continuous(m1)) =>
DistTypes.Continuous( Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toShape
Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2),
)
| (Discrete(m1), Discrete(m2)) => | (Discrete(m1), Discrete(m2)) =>
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2)) Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toShape
| (m1, m2) => | (m1, m2) =>
DistTypes.Mixed(
Mixed.combineAlgebraically( Mixed.combineAlgebraically(
op, op,
toMixed(m1), toMixed(m1),
toMixed(m2), toMixed(m2),
),
) )
|> Mixed.T.toShape
}; };
}; };

View File

@ -20,61 +20,15 @@ module AlgebraicCombination = {
| _ => Ok(`AlgebraicCombination((operation, t1, t2))) | _ => Ok(`AlgebraicCombination((operation, t1, t2)))
}; };
let tryCombination = (n, algebraicOp, t1: node, t2: node) => { let combinationByRendering =
let sampleN = (evaluationParams, algebraicOp, t1: node, t2: node)
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n)); : result(node, string) => {
switch (sampleN(t1), sampleN(t2)) { E.R.merge(
| (Some(a), Some(b)) => Render.ensureIsRenderedAndGetShape(evaluationParams, t1),
Some( Render.ensureIsRenderedAndGetShape(evaluationParams, t2),
Belt.Array.zip(a, b)
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
) )
| _ => None |> E.R.fmap(((a, b)) =>
}; `RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b))
};
let renderIfNotRendered = (params, t) =>
!renderable(t)
? switch (render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let combineAsShapes =
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let i1 = renderIfNotRendered(evaluationParams, t1);
let i2 = renderIfNotRendered(evaluationParams, t2);
E.R.merge(i1, i2)
|> E.R.bind(
_,
((a, b)) => {
let samples =
tryCombination(
evaluationParams.samplingInputs.sampleCount,
algebraicOp,
a,
b,
);
let shape =
samples
|> E.O.fmap(
Samples.T.fromSamples(
~samplingInputs={
sampleCount:
Some(evaluationParams.samplingInputs.sampleCount),
outputXYPoints:
Some(evaluationParams.samplingInputs.outputXYPoints),
kernelWidth: evaluationParams.samplingInputs.kernelWidth,
},
),
)
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
r.shape
)
|> E.O.toResult("No response");
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
},
); );
}; };
@ -92,7 +46,7 @@ module AlgebraicCombination = {
_, _,
fun fun
| `SymbolicDist(d) as t => Ok(t) | `SymbolicDist(d) as t => Ok(t)
| _ => combineAsShapes(evaluationParams, algebraicOp, t1, t2), | _ => combinationByRendering(evaluationParams, algebraicOp, t1, t2),
); );
}; };
@ -103,7 +57,7 @@ module VerticalScaling = {
let fn = Operation.Scale.toFn(scaleOp); let fn = Operation.Scale.toFn(scaleOp);
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp); let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp); let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
let renderedShape = render(evaluationParams, t); let renderedShape = Render.render(evaluationParams, t);
switch (renderedShape, scaleBy) { switch (renderedShape, scaleBy) {
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) => | (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
@ -125,13 +79,22 @@ module VerticalScaling = {
module PointwiseCombination = { module PointwiseCombination = {
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => { let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
switch (render(evaluationParams, t1), render(evaluationParams, t2)) { switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) {
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
Ok( Ok(
`RenderedDist( `RenderedDist(
Shape.combinePointwise( Shape.combinePointwise(
~integralSumCachesFn=(a, b) => Some(a +. b), ~integralSumCachesFn=(a, b) => Some(a +. b),
~integralCachesFn=(a, b) => Some(Continuous.combinePointwise(~distributionType=`CDF, (+.), a, b)), ~integralCachesFn=
(a, b) =>
Some(
Continuous.combinePointwise(
~distributionType=`CDF,
(+.),
a,
b,
),
),
(+.), (+.),
rs1, rs1,
rs2, rs2,
@ -153,7 +116,12 @@ module PointwiseCombination = {
}; };
let operationToLeaf = let operationToLeaf =
(evaluationParams: evaluationParams, pointwiseOp: pointwiseOperation, t1: t, t2: t) => { (
evaluationParams: evaluationParams,
pointwiseOp: pointwiseOperation,
t1: t,
t2: t,
) => {
switch (pointwiseOp) { switch (pointwiseOp) {
| `Add => pointwiseAdd(evaluationParams, t1, t2) | `Add => pointwiseAdd(evaluationParams, t1, t2)
| `Multiply => pointwiseMultiply(evaluationParams, t1, t2) | `Multiply => pointwiseMultiply(evaluationParams, t1, t2)
@ -166,7 +134,7 @@ module Truncate = {
switch (leftCutoff, rightCutoff, t) { switch (leftCutoff, rightCutoff, t) {
| (None, None, t) => `Solution(t) | (None, None, t) => `Solution(t)
| (Some(lc), Some(rc), t) when lc > rc => | (Some(lc), Some(rc), t) when lc > rc =>
`Error("Left truncation bound must be smaller than right bound.") `Error("Left truncation bound must be smaller than right truncation bound.")
| (lc, rc, `SymbolicDist(`Uniform(u))) => | (lc, rc, `SymbolicDist(`Uniform(u))) =>
`Solution( `Solution(
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))), `SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
@ -179,7 +147,7 @@ module Truncate = {
(evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => { (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => {
// TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
// of a distribution we otherwise wouldn't get at all // of a distribution we otherwise wouldn't get at all
switch (render(evaluationParams, t)) { switch (Render.ensureIsRendered(evaluationParams, t)) {
| Ok(`RenderedDist(rs)) => | Ok(`RenderedDist(rs)) =>
Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs))) Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
| Error(e) => Error(e) | Error(e) => Error(e)

View File

@ -1,7 +1,13 @@
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide]; type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
type pointwiseOperation = [ | `Add | `Multiply]; type pointwiseOperation = [ | `Add | `Multiply];
type scaleOperation = [ | `Multiply | `Exponentiate | `Log]; type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
type distToFloatOperation = [ | `Pdf(float) | `Cdf(float) | `Inv(float) | `Mean | `Sample]; type distToFloatOperation = [
| `Pdf(float)
| `Cdf(float)
| `Inv(float)
| `Mean
| `Sample
];
module ExpressionTree = { module ExpressionTree = {
type node = [ type node = [
@ -31,28 +37,44 @@ module ExpressionTree = {
let evaluateNode = (evaluationParams: evaluationParams) => let evaluateNode = (evaluationParams: evaluationParams) =>
evaluationParams.evaluateNode(evaluationParams); evaluationParams.evaluateNode(evaluationParams);
let render = (evaluationParams: evaluationParams, r) =>
evaluateNode(evaluationParams, `Render(r));
let evaluateAndRetry = (evaluationParams, fn, node) => let evaluateAndRetry = (evaluationParams, fn, node) =>
node node
|> evaluationParams.evaluateNode(evaluationParams) |> evaluationParams.evaluateNode(evaluationParams)
|> E.R.bind(_, fn(evaluationParams)); |> E.R.bind(_, fn(evaluationParams));
let renderable = module Render = {
fun type t = node;
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
let mapRenderable = (renderedFn, symFn, item: node) => let render = (evaluationParams: evaluationParams, r) =>
`Render(r) |> evaluateNode(evaluationParams);
let ensureIsRendered = (params, t) =>
switch (t) {
| `RenderedDist(_) => Ok(t)
| _ =>
switch (render(params, t)) {
| Ok(`RenderedDist(r)) => Ok(`RenderedDist(r))
| Ok(_) => Error("Did not render as requested")
| Error(e) => Error(e)
}
};
let ensureIsRenderedAndGetShape = (params, t) =>
switch (ensureIsRendered(params, t)) {
| Ok(`RenderedDist(r)) => Ok(r)
| Ok(_) => Error("Did not render as requested")
| Error(e) => Error(e)
};
let getShape = (item: node) =>
switch (item) { switch (item) {
| `SymbolicDist(s) => Some(symFn(s)) | `RenderedDist(r) => Some(r)
| `RenderedDist(r) => Some(renderedFn(r))
| _ => None | _ => None
}; };
}; };
};
type simplificationResult = [ type simplificationResult = [
| `Solution(ExpressionTree.node) | `Solution(ExpressionTree.node)
| `Error(string) | `Error(string)

View File

@ -0,0 +1,80 @@
open ExpressionTypes.ExpressionTree;
let isSamplingDistribution: node => bool =
fun
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
let renderIfIsNotSamplingDistribution = (params, t) =>
!isSamplingDistribution(t)
? switch (Render.render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
node
|> (
fun
| `RenderedDist(r) => Some(renderedDistFn(r))
| `SymbolicDist(s) => Some(symbolicDistFn(s))
| _ => None
);
let sampleN = n =>
map(
~renderedDistFn=Shape.sampleNRendered(n),
~symbolicDistFn=SymbolicDist.T.sampleN(n),
);
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => {
switch (sampleN(n, t1), sampleN(n, t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b)
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| _ => None
};
};
let combineShapesUsingSampling =
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1);
let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2);
E.R.merge(i1, i2)
|> E.R.bind(
_,
((a, b)) => {
let samples =
getCombinationSamples(
evaluationParams.samplingInputs.sampleCount,
algebraicOp,
a,
b,
);
// todo: This bottom part should probably be somewhere else.
let shape =
samples
|> E.O.fmap(
Samples.T.fromSamples(
~samplingInputs={
sampleCount:
Some(evaluationParams.samplingInputs.sampleCount),
outputXYPoints:
Some(evaluationParams.samplingInputs.outputXYPoints),
kernelWidth: evaluationParams.samplingInputs.kernelWidth,
},
),
)
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
r.shape
)
|> E.O.toResult("No response");
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
},
);
};