Fix conflicts

This commit is contained in:
Sebastian Kosch 2020-07-18 21:45:47 -07:00
commit d9cb164e5d
5 changed files with 181 additions and 109 deletions

View File

@ -110,7 +110,11 @@ let toDiscretePointMassesFromTriangulars =
};
let combineShapesContinuousContinuous =
(op: ExpressionTypes.algebraicOperation, s1: DistTypes.xyShape, s2: DistTypes.xyShape)
(
op: ExpressionTypes.algebraicOperation,
s1: DistTypes.xyShape,
s2: DistTypes.xyShape,
)
: DistTypes.xyShape => {
let t1n = s1 |> XYShape.T.length;
let t2n = s2 |> XYShape.T.length;
@ -118,10 +122,11 @@ let combineShapesContinuousContinuous =
// if we add the two distributions, we should probably use normal filters.
// if we multiply the two distributions, we should probably use lognormal filters.
let t1m = toDiscretePointMassesFromTriangulars(s1);
let t2m = switch (op) {
let t2m =
switch (op) {
| `Divide => toDiscretePointMassesFromTriangulars(~inverse=true, s2)
| _ => toDiscretePointMassesFromTriangulars(~inverse=false, s2)
};
};
let combineMeansFn =
switch (op) {
@ -185,27 +190,30 @@ let combineShapesContinuousContinuous =
// we now want to create a set of target points. For now, let's just evenly distribute 200 points between
// between the outputMinX and outputMaxX
let nOut = 300;
let outputXs: array(float) = E.A.Floats.range(outputMinX^, outputMaxX^, nOut);
let outputXs: array(float) =
E.A.Floats.range(outputMinX^, outputMaxX^, nOut);
let outputYs: array(float) = Belt.Array.make(nOut, 0.0);
// now, for each of the outputYs, accumulate from a Gaussian kernel over each input point.
for (j in 0 to E.A.length(masses) - 1) { // go through all of the result points
for (j in 0 to E.A.length(masses) - 1) {
// go through all of the result points
if (variances[j] > 0. && masses[j] > 0.) {
for (i in 0 to E.A.length(outputXs) - 1) { // go through all of the target points
for (i in 0 to E.A.length(outputXs) - 1) {
// go through all of the target points
let dx = outputXs[i] -. means[j];
let contribution = masses[j] *. exp(-. (dx ** 2.) /. (2. *. variances[j])) /. (sqrt(2. *. 3.14159276 *. variances[j]));
let contribution =
masses[j]
*. exp(-. (dx ** 2.) /. (2. *. variances[j]))
/. sqrt(2. *. 3.14159276 *. variances[j]);
Belt.Array.set(outputYs, i, outputYs[i] +. contribution) |> ignore;
();
};
();
} |> ignore;
();
};
};
{xs: outputXs, ys: outputYs};
};
let toDiscretePointMassesFromDiscrete = (s: DistTypes.xyShape): pointMassesWithMoments => {
let n = s |> XYShape.T.length;
let toDiscretePointMassesFromDiscrete =
(s: DistTypes.xyShape): pointMassesWithMoments => {
let {xs, ys}: XYShape.T.t = s;
let n = E.A.length(xs);
@ -231,7 +239,7 @@ let combineShapesContinuousDiscrete =
switch (op) {
| `Add
| `Subtract => {
| `Subtract =>
for (j in 0 to t2n - 1) {
// creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes.
let dxyShape: array((float, float)) =
@ -240,16 +248,16 @@ let combineShapesContinuousDiscrete =
Belt.Array.set(
dxyShape,
i,
(fn(continuousShape.xs[i], discreteShape.xs[j]), continuousShape.ys[i] *. discreteShape.ys[j]),
(fn(continuousShape.xs[i], discreteShape.xs[j]),
continuousShape.ys[i] *. discreteShape.ys[j]),
) |> ignore;
();
};
Belt.Array.set(outXYShapes, j, dxyShape) |> ignore;
();
}
}
| `Multiply
| `Divide => {
| `Multiply
| `Divide =>
for (j in 0 to t2n - 1) {
// creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes.
let dxyShape: array((float, float)) =
@ -265,7 +273,6 @@ let combineShapesContinuousDiscrete =
Belt.Array.set(outXYShapes, j, dxyShape) |> ignore;
();
}
}
};
outXYShapes

View File

@ -27,24 +27,19 @@ let combineAlgebraically =
(op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t => {
switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) =>
DistTypes.Continuous(
Continuous.combineAlgebraically(op, m1, m2),
)
Continuous.combineAlgebraically(op, m1, m2) |> Continuous.T.toShape;
| (Continuous(m1), Discrete(m2))
| (Discrete(m2), Continuous(m1)) =>
DistTypes.Continuous(
Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2),
)
Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toShape
| (Discrete(m1), Discrete(m2)) =>
DistTypes.Discrete(Discrete.combineAlgebraically(op, m1, m2))
Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toShape
| (m1, m2) =>
DistTypes.Mixed(
Mixed.combineAlgebraically(
op,
toMixed(m1),
toMixed(m2),
),
Mixed.combineAlgebraically(
op,
toMixed(m1),
toMixed(m2),
)
|> Mixed.T.toShape
};
};

View File

@ -20,61 +20,15 @@ module AlgebraicCombination = {
| _ => Ok(`AlgebraicCombination((operation, t1, t2)))
};
let tryCombination = (n, algebraicOp, t1: node, t2: node) => {
let sampleN =
mapRenderable(Shape.sampleNRendered(n), SymbolicDist.T.sampleN(n));
switch (sampleN(t1), sampleN(t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b)
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| _ => None
};
};
let renderIfNotRendered = (params, t) =>
!renderable(t)
? switch (render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let combineAsShapes =
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let i1 = renderIfNotRendered(evaluationParams, t1);
let i2 = renderIfNotRendered(evaluationParams, t2);
E.R.merge(i1, i2)
|> E.R.bind(
_,
((a, b)) => {
let samples =
tryCombination(
evaluationParams.samplingInputs.sampleCount,
algebraicOp,
a,
b,
);
let shape =
samples
|> E.O.fmap(
Samples.T.fromSamples(
~samplingInputs={
sampleCount:
Some(evaluationParams.samplingInputs.sampleCount),
outputXYPoints:
Some(evaluationParams.samplingInputs.outputXYPoints),
kernelWidth: evaluationParams.samplingInputs.kernelWidth,
},
),
)
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
r.shape
)
|> E.O.toResult("No response");
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
},
let combinationByRendering =
(evaluationParams, algebraicOp, t1: node, t2: node)
: result(node, string) => {
E.R.merge(
Render.ensureIsRenderedAndGetShape(evaluationParams, t1),
Render.ensureIsRenderedAndGetShape(evaluationParams, t2),
)
|> E.R.fmap(((a, b)) =>
`RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b))
);
};
@ -92,7 +46,7 @@ module AlgebraicCombination = {
_,
fun
| `SymbolicDist(d) as t => Ok(t)
| _ => combineAsShapes(evaluationParams, algebraicOp, t1, t2),
| _ => combinationByRendering(evaluationParams, algebraicOp, t1, t2),
);
};
@ -103,7 +57,7 @@ module VerticalScaling = {
let fn = Operation.Scale.toFn(scaleOp);
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
let renderedShape = render(evaluationParams, t);
let renderedShape = Render.render(evaluationParams, t);
switch (renderedShape, scaleBy) {
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(sm))) =>
@ -125,13 +79,22 @@ module VerticalScaling = {
module PointwiseCombination = {
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
switch (render(evaluationParams, t1), render(evaluationParams, t2)) {
switch (Render.render(evaluationParams, t1), Render.render(evaluationParams, t2)) {
| (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) =>
Ok(
`RenderedDist(
Shape.combinePointwise(
~integralSumCachesFn=(a, b) => Some(a +. b),
~integralCachesFn=(a, b) => Some(Continuous.combinePointwise(~distributionType=`CDF, (+.), a, b)),
~integralCachesFn=
(a, b) =>
Some(
Continuous.combinePointwise(
~distributionType=`CDF,
(+.),
a,
b,
),
),
(+.),
rs1,
rs2,
@ -153,7 +116,12 @@ module PointwiseCombination = {
};
let operationToLeaf =
(evaluationParams: evaluationParams, pointwiseOp: pointwiseOperation, t1: t, t2: t) => {
(
evaluationParams: evaluationParams,
pointwiseOp: pointwiseOperation,
t1: t,
t2: t,
) => {
switch (pointwiseOp) {
| `Add => pointwiseAdd(evaluationParams, t1, t2)
| `Multiply => pointwiseMultiply(evaluationParams, t1, t2)
@ -166,7 +134,7 @@ module Truncate = {
switch (leftCutoff, rightCutoff, t) {
| (None, None, t) => `Solution(t)
| (Some(lc), Some(rc), t) when lc > rc =>
`Error("Left truncation bound must be smaller than right bound.")
`Error("Left truncation bound must be smaller than right truncation bound.")
| (lc, rc, `SymbolicDist(`Uniform(u))) =>
`Solution(
`SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))),
@ -179,7 +147,7 @@ module Truncate = {
(evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => {
// TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail
// of a distribution we otherwise wouldn't get at all
switch (render(evaluationParams, t)) {
switch (Render.ensureIsRendered(evaluationParams, t)) {
| Ok(`RenderedDist(rs)) =>
Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs)))
| Error(e) => Error(e)

View File

@ -1,7 +1,13 @@
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
type pointwiseOperation = [ | `Add | `Multiply];
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
type distToFloatOperation = [ | `Pdf(float) | `Cdf(float) | `Inv(float) | `Mean | `Sample];
type distToFloatOperation = [
| `Pdf(float)
| `Cdf(float)
| `Inv(float)
| `Mean
| `Sample
];
module ExpressionTree = {
type node = [
@ -31,26 +37,42 @@ module ExpressionTree = {
let evaluateNode = (evaluationParams: evaluationParams) =>
evaluationParams.evaluateNode(evaluationParams);
let render = (evaluationParams: evaluationParams, r) =>
evaluateNode(evaluationParams, `Render(r));
let evaluateAndRetry = (evaluationParams, fn, node) =>
node
|> evaluationParams.evaluateNode(evaluationParams)
|> E.R.bind(_, fn(evaluationParams));
let renderable =
fun
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
module Render = {
type t = node;
let render = (evaluationParams: evaluationParams, r) =>
`Render(r) |> evaluateNode(evaluationParams);
let ensureIsRendered = (params, t) =>
switch (t) {
| `RenderedDist(_) => Ok(t)
| _ =>
switch (render(params, t)) {
| Ok(`RenderedDist(r)) => Ok(`RenderedDist(r))
| Ok(_) => Error("Did not render as requested")
| Error(e) => Error(e)
}
};
let ensureIsRenderedAndGetShape = (params, t) =>
switch (ensureIsRendered(params, t)) {
| Ok(`RenderedDist(r)) => Ok(r)
| Ok(_) => Error("Did not render as requested")
| Error(e) => Error(e)
};
let getShape = (item: node) =>
switch (item) {
| `RenderedDist(r) => Some(r)
| _ => None
};
};
let mapRenderable = (renderedFn, symFn, item: node) =>
switch (item) {
| `SymbolicDist(s) => Some(symFn(s))
| `RenderedDist(r) => Some(renderedFn(r))
| _ => None
};
};
type simplificationResult = [

View File

@ -0,0 +1,80 @@
open ExpressionTypes.ExpressionTree;
let isSamplingDistribution: node => bool =
fun
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
let renderIfIsNotSamplingDistribution = (params, t) =>
!isSamplingDistribution(t)
? switch (Render.render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
node
|> (
fun
| `RenderedDist(r) => Some(renderedDistFn(r))
| `SymbolicDist(s) => Some(symbolicDistFn(s))
| _ => None
);
let sampleN = n =>
map(
~renderedDistFn=Shape.sampleNRendered(n),
~symbolicDistFn=SymbolicDist.T.sampleN(n),
);
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => {
switch (sampleN(n, t1), sampleN(n, t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b)
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| _ => None
};
};
let combineShapesUsingSampling =
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1);
let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2);
E.R.merge(i1, i2)
|> E.R.bind(
_,
((a, b)) => {
let samples =
getCombinationSamples(
evaluationParams.samplingInputs.sampleCount,
algebraicOp,
a,
b,
);
// todo: This bottom part should probably be somewhere else.
let shape =
samples
|> E.O.fmap(
Samples.T.fromSamples(
~samplingInputs={
sampleCount:
Some(evaluationParams.samplingInputs.sampleCount),
outputXYPoints:
Some(evaluationParams.samplingInputs.outputXYPoints),
kernelWidth: evaluationParams.samplingInputs.kernelWidth,
},
),
)
|> E.O.bind(_, (r: RenderTypes.ShapeRenderer.Sampling.outputs) =>
r.shape
)
|> E.O.toResult("No response");
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
},
);
};