Moved samplingDistribution to PTypes

This commit is contained in:
Ozzie Gooen 2020-08-06 17:08:26 +01:00
parent f9607006d5
commit f72a3ddde6
4 changed files with 95 additions and 77 deletions

View File

@ -49,11 +49,11 @@ module AlgebraicCombination = {
(evaluationParams, algebraicOp, t1: node, t2: node)
: result(node, string) => {
E.R.merge(
SamplingDistribution.renderIfIsNotSamplingDistribution(
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
t1,
),
SamplingDistribution.renderIfIsNotSamplingDistribution(
PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(
evaluationParams,
t2,
),
@ -61,7 +61,7 @@ module AlgebraicCombination = {
|> E.R.bind(_, ((a, b)) =>
switch (choose(a, b)) {
| `Sampling =>
SamplingDistribution.combineShapesUsingSampling(
PTypes.SamplingDistribution.combineShapesUsingSampling(
evaluationParams,
algebraicOp,
a,
@ -296,6 +296,21 @@ let run = (node, fnNode) => {
};
};
// let outputType = (t:t) => t |> fun
// | `SymbolicDist(_) => `RenderedDist
// | `RenderedDist(_) => `RenderedDist
// | `Bool(_) => `RenderedDist
// | `AlgebraicCombination(_) => `RenderedDist
// | `PointwiseCombination(_) => `RenderedDist
// | `VerticalScaling(_) => `RenderedDist
// | `Truncate(_) => `RenderedDist
// | `FloatFromDist(_) => `RenderedDist
// | `Normalize(_) => `RenderedDist
// | `Render(_) => `RenderedDist
// | `Function(_) => `Function
// | `FunctionCall(_) => `Any
// | `Symbol(_) => `Any
/* This function recursively goes through the nodes of the parse tree,
replacing each Operation node and its subtree with a Data node.
Whenever possible, the replacement produces a new Symbolic Data node,
@ -335,6 +350,7 @@ let toLeaf =
FloatFromDist.operationToLeaf(evaluationParams, distToFloatOp, t)
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
| `Render(t) => Render.operationToLeaf(evaluationParams, t)
| `Function(_) => Error("Function must be called with params")
| `Symbol(r) => ExpressionTypes.ExpressionTree.Environment.get(evaluationParams.environment, r) |> E.O.toResult("Undeclared variable " ++ r)
| `FunctionCall(name, args) =>

View File

@ -13,6 +13,8 @@ module ExpressionTree = {
type node = [
| `SymbolicDist(SymbolicTypes.symbolicDist)
| `RenderedDist(DistTypes.shape)
| `Symbol(string)
| `Function(array(string), node)
| `AlgebraicCombination(algebraicOperation, node, node)
| `PointwiseCombination(pointwiseOperation, node, node)
| `VerticalScaling(scaleOperation, node, node)
@ -20,9 +22,7 @@ module ExpressionTree = {
| `Truncate(option(float), option(float), node)
| `Normalize(node)
| `FloatFromDist(distToFloatOperation, node)
| `Function(array(string), node)
| `FunctionCall(string, array(node))
| `Symbol(string)
];
type samplingInputs = {

View File

@ -0,0 +1,74 @@
open ExpressionTypes.ExpressionTree;
module SamplingDistribution = {
let isSamplingDistribution: node => bool =
fun
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
let renderIfIsNotSamplingDistribution = (params, t): result(node, string) =>
!isSamplingDistribution(t)
? switch (Render.render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
node
|> (
fun
| `RenderedDist(r) => Some(renderedDistFn(r))
| `SymbolicDist(s) => Some(symbolicDistFn(s))
| _ => None
);
let sampleN = n =>
map(
~renderedDistFn=Shape.sampleNRendered(n),
~symbolicDistFn=SymbolicDist.T.sampleN(n),
);
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => {
switch (sampleN(n, t1), sampleN(n, t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b)
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| _ => None
};
};
let combineShapesUsingSampling =
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1);
let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2);
E.R.merge(i1, i2)
|> E.R.bind(
_,
((a, b)) => {
let samples =
getCombinationSamples(
evaluationParams.samplingInputs.sampleCount,
algebraicOp,
a,
b,
);
// todo: This bottom part should probably be somewhere else.
let shape =
samples
|> E.O.fmap(
SamplesToShape.fromSamples(
~samplingInputs=evaluationParams.samplingInputs,
),
)
|> E.O.bind(_, r => r.shape)
|> E.O.toResult("No response");
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
},
);
};
};

View File

@ -1,72 +0,0 @@
open ExpressionTypes.ExpressionTree;
let isSamplingDistribution: node => bool =
fun
| `SymbolicDist(_) => true
| `RenderedDist(_) => true
| _ => false;
let renderIfIsNotSamplingDistribution = (params, t): result(node, string) =>
!isSamplingDistribution(t)
? switch (Render.render(params, t)) {
| Ok(r) => Ok(r)
| Error(e) => Error(e)
}
: Ok(t);
let map = (~renderedDistFn, ~symbolicDistFn, node: node) =>
node
|> (
fun
| `RenderedDist(r) => Some(renderedDistFn(r))
| `SymbolicDist(s) => Some(symbolicDistFn(s))
| _ => None
);
let sampleN = n =>
map(
~renderedDistFn=Shape.sampleNRendered(n),
~symbolicDistFn=SymbolicDist.T.sampleN(n),
);
let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => {
switch (sampleN(n, t1), sampleN(n, t2)) {
| (Some(a), Some(b)) =>
Some(
Belt.Array.zip(a, b)
|> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)),
)
| _ => None
};
};
let combineShapesUsingSampling =
(evaluationParams: evaluationParams, algebraicOp, t1: node, t2: node) => {
let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1);
let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2);
E.R.merge(i1, i2)
|> E.R.bind(
_,
((a, b)) => {
let samples =
getCombinationSamples(
evaluationParams.samplingInputs.sampleCount,
algebraicOp,
a,
b,
);
// todo: This bottom part should probably be somewhere else.
let shape =
samples
|> E.O.fmap(
SamplesToShape.fromSamples(
~samplingInputs=evaluationParams.samplingInputs,
),
)
|> E.O.bind(_, r => r.shape)
|> E.O.toResult("No response");
shape |> E.R.fmap(r => `Normalize(`RenderedDist(r)));
},
);
};