Made more dists using new format
This commit is contained in:
parent
c57cc3144e
commit
479fdbb491
|
@ -8,8 +8,6 @@ let rec toString: node => string =
|
||||||
Operation.Algebraic.format(op, toString(t1), toString(t2))
|
Operation.Algebraic.format(op, toString(t1), toString(t2))
|
||||||
| `PointwiseCombination(op, t1, t2) =>
|
| `PointwiseCombination(op, t1, t2) =>
|
||||||
Operation.Pointwise.format(op, toString(t1), toString(t2))
|
Operation.Pointwise.format(op, toString(t1), toString(t2))
|
||||||
| `VerticalScaling(scaleOp, t, scaleBy) =>
|
|
||||||
Operation.Scale.format(scaleOp, toString(t), toString(scaleBy))
|
|
||||||
| `Normalize(t) => "normalize(k" ++ toString(t) ++ ")"
|
| `Normalize(t) => "normalize(k" ++ toString(t) ++ ")"
|
||||||
| `Truncate(lc, rc, t) =>
|
| `Truncate(lc, rc, t) =>
|
||||||
Operation.T.truncateToString(lc, rc, toString(t))
|
Operation.T.truncateToString(lc, rc, toString(t))
|
||||||
|
|
|
@ -91,36 +91,6 @@ module AlgebraicCombination = {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
module VerticalScaling = {
|
|
||||||
let operationToLeaf =
|
|
||||||
(evaluationParams: evaluationParams, scaleOp, t, scaleBy) => {
|
|
||||||
// scaleBy has to be a single float, otherwise we'll return an error.
|
|
||||||
let fn = (secondary, main) =>
|
|
||||||
Operation.Scale.toFn(scaleOp, main, secondary);
|
|
||||||
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
|
|
||||||
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
|
|
||||||
let renderedShape = Render.render(evaluationParams, t);
|
|
||||||
|
|
||||||
let s =
|
|
||||||
switch (renderedShape, scaleBy) {
|
|
||||||
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(scaleBy))) =>
|
|
||||||
Ok(
|
|
||||||
`RenderedDist(
|
|
||||||
Shape.T.mapY(
|
|
||||||
~integralSumCacheFn=integralSumCacheFn(scaleBy),
|
|
||||||
~integralCacheFn=integralCacheFn(scaleBy),
|
|
||||||
~fn=fn(scaleBy),
|
|
||||||
rs,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
| (Error(e1), _) => Error(e1)
|
|
||||||
| (_, _) => Error("Can only scale by float values.")
|
|
||||||
};
|
|
||||||
s;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
module PointwiseCombination = {
|
module PointwiseCombination = {
|
||||||
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
|
||||||
switch (
|
switch (
|
||||||
|
@ -309,8 +279,6 @@ let rec toLeaf =
|
||||||
t1,
|
t1,
|
||||||
t2,
|
t2,
|
||||||
)
|
)
|
||||||
| `VerticalScaling(scaleOp, t, scaleBy) =>
|
|
||||||
VerticalScaling.operationToLeaf(evaluationParams, scaleOp, t, scaleBy)
|
|
||||||
| `Truncate(leftCutoff, rightCutoff, t) =>
|
| `Truncate(leftCutoff, rightCutoff, t) =>
|
||||||
Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t)
|
Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t)
|
||||||
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
|
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
|
||||||
|
@ -336,12 +304,7 @@ let rec toLeaf =
|
||||||
let components =
|
let components =
|
||||||
r
|
r
|
||||||
|> E.A.fmap(((dist, weight)) =>
|
|> E.A.fmap(((dist, weight)) =>
|
||||||
`VerticalScaling((
|
`FunctionCall("scaleExp", [|dist, `SymbolicDist(`Float(weight))|]));
|
||||||
`Multiply,
|
|
||||||
dist,
|
|
||||||
`SymbolicDist(`Float(weight)),
|
|
||||||
))
|
|
||||||
);
|
|
||||||
let pointwiseSum =
|
let pointwiseSum =
|
||||||
components
|
components
|
||||||
|> Js.Array.sliceFrom(1)
|
|> Js.Array.sliceFrom(1)
|
||||||
|
|
|
@ -25,7 +25,6 @@ module ExpressionTree = {
|
||||||
| `Function(array(string), node)
|
| `Function(array(string), node)
|
||||||
| `AlgebraicCombination(algebraicOperation, node, node)
|
| `AlgebraicCombination(algebraicOperation, node, node)
|
||||||
| `PointwiseCombination(pointwiseOperation, node, node)
|
| `PointwiseCombination(pointwiseOperation, node, node)
|
||||||
| `VerticalScaling(scaleOperation, node, node)
|
|
||||||
| `Normalize(node)
|
| `Normalize(node)
|
||||||
| `Render(node)
|
| `Render(node)
|
||||||
| `Truncate(option(float), option(float), node)
|
| `Truncate(option(float), option(float), node)
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
open TypeSystem;
|
open TypeSystem;
|
||||||
|
|
||||||
let wrongInputsError = (r) => {Js.log2("Wrong inputs", r); Error("Wrong inputs")};
|
let wrongInputsError = r => {
|
||||||
|
Js.log2("Wrong inputs", r);
|
||||||
|
Error("Wrong inputs");
|
||||||
|
};
|
||||||
|
|
||||||
let to_: (float, float) => result(node, string) =
|
let to_: (float, float) => result(node, string) =
|
||||||
(low, high) =>
|
(low, high) =>
|
||||||
|
@ -20,7 +23,7 @@ let makeSymbolicFromTwoFloats = (name, fn) =>
|
||||||
~run=
|
~run=
|
||||||
fun
|
fun
|
||||||
| [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
|
| [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
|
||||||
| e => wrongInputsError(e)
|
| e => wrongInputsError(e),
|
||||||
);
|
);
|
||||||
|
|
||||||
let makeSymbolicFromOneFloat = (name, fn) =>
|
let makeSymbolicFromOneFloat = (name, fn) =>
|
||||||
|
@ -31,21 +34,32 @@ let makeSymbolicFromOneFloat = (name, fn) =>
|
||||||
~run=
|
~run=
|
||||||
fun
|
fun
|
||||||
| [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
|
| [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
|
||||||
| e => wrongInputsError(e)
|
| e => wrongInputsError(e),
|
||||||
);
|
);
|
||||||
|
|
||||||
let makeDistFloat = (name, fn) =>
|
let makeDistFloat = (name, fn) =>
|
||||||
Function.make(
|
Function.make(
|
||||||
~name,
|
~name,
|
||||||
~output=`SamplingDistribution,
|
~output=`SamplingDistribution,
|
||||||
~inputs=[|`SamplingDistribution, `Float|],
|
~inputs=[|`SamplingDistribution, `Float|],
|
||||||
~run=
|
~run=
|
||||||
fun
|
fun
|
||||||
| [|`SamplingDist(a), `Float(b)|] => (fn(a,b))
|
| [|`SamplingDist(a), `Float(b)|] => fn(a, b)
|
||||||
| e => wrongInputsError(e)
|
| e => wrongInputsError(e),
|
||||||
);
|
);
|
||||||
|
|
||||||
let makeDist = (name, fn) =>
|
let makeRenderedDistFloat = (name, fn) =>
|
||||||
|
Function.make(
|
||||||
|
~name,
|
||||||
|
~output=`RenderedDistribution,
|
||||||
|
~inputs=[|`RenderedDistribution, `Float|],
|
||||||
|
~run=
|
||||||
|
fun
|
||||||
|
| [|`RenderedDist(a), `Float(b)|] => fn(a, b)
|
||||||
|
| e => wrongInputsError(e),
|
||||||
|
);
|
||||||
|
|
||||||
|
let makeDist = (name, fn) =>
|
||||||
Function.make(
|
Function.make(
|
||||||
~name,
|
~name,
|
||||||
~output=`SamplingDistribution,
|
~output=`SamplingDistribution,
|
||||||
|
@ -53,7 +67,7 @@ let makeDist = (name, fn) =>
|
||||||
~run=
|
~run=
|
||||||
fun
|
fun
|
||||||
| [|`SamplingDist(a)|] => fn(a)
|
| [|`SamplingDist(a)|] => fn(a)
|
||||||
| e => wrongInputsError(e)
|
| e => wrongInputsError(e),
|
||||||
);
|
);
|
||||||
|
|
||||||
let floatFromDist =
|
let floatFromDist =
|
||||||
|
@ -71,6 +85,22 @@ let floatFromDist =
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let verticalScaling = (scaleOp, rs, scaleBy) => {
|
||||||
|
// scaleBy has to be a single float, otherwise we'll return an error.
|
||||||
|
let fn = (secondary, main) =>
|
||||||
|
Operation.Scale.toFn(scaleOp, main, secondary);
|
||||||
|
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
|
||||||
|
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
|
||||||
|
Ok(`RenderedDist(
|
||||||
|
Shape.T.mapY(
|
||||||
|
~integralSumCacheFn=integralSumCacheFn(scaleBy),
|
||||||
|
~integralCacheFn=integralCacheFn(scaleBy),
|
||||||
|
~fn=fn(scaleBy),
|
||||||
|
rs,
|
||||||
|
),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
let functions = [|
|
let functions = [|
|
||||||
makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
|
makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
|
||||||
makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
|
makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
|
||||||
|
@ -87,8 +117,8 @@ let functions = [|
|
||||||
~inputs=[|`Float, `Float|],
|
~inputs=[|`Float, `Float|],
|
||||||
~run=
|
~run=
|
||||||
fun
|
fun
|
||||||
| [|`Float(a), `Float(b)|] => to_(a,b)
|
| [|`Float(a), `Float(b)|] => to_(a, b)
|
||||||
| e => wrongInputsError(e)
|
| e => wrongInputsError(e),
|
||||||
),
|
),
|
||||||
Function.make(
|
Function.make(
|
||||||
~name="triangular",
|
~name="triangular",
|
||||||
|
@ -99,11 +129,39 @@ let functions = [|
|
||||||
| [|`Float(a), `Float(b), `Float(c)|] =>
|
| [|`Float(a), `Float(b), `Float(c)|] =>
|
||||||
SymbolicDist.Triangular.make(a, b, c)
|
SymbolicDist.Triangular.make(a, b, c)
|
||||||
|> E.R.fmap(r => `SymbolicDist(r))
|
|> E.R.fmap(r => `SymbolicDist(r))
|
||||||
| e => wrongInputsError(e)
|
| e => wrongInputsError(e),
|
||||||
),
|
),
|
||||||
makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
|
makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
|
||||||
makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
|
makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
|
||||||
makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
|
makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
|
||||||
makeDist("mean", (dist) => floatFromDist(`Mean, dist)),
|
makeDist("mean", dist => floatFromDist(`Mean, dist)),
|
||||||
makeDist("sample", (dist) => floatFromDist(`Sample, dist))
|
makeDist("sample", dist => floatFromDist(`Sample, dist)),
|
||||||
|
Function.make(
|
||||||
|
~name="render",
|
||||||
|
~output=`RenderedDistribution,
|
||||||
|
~inputs=[|`RenderedDistribution|],
|
||||||
|
~run=
|
||||||
|
fun
|
||||||
|
| [|`RenderedDist(c)|] => Ok(`RenderedDist(c))
|
||||||
|
| e => wrongInputsError(e),
|
||||||
|
),
|
||||||
|
Function.make(
|
||||||
|
~name="normalize",
|
||||||
|
~output=`SamplingDistribution,
|
||||||
|
~inputs=[|`SamplingDistribution|],
|
||||||
|
~run=
|
||||||
|
fun
|
||||||
|
| [|`SamplingDist(`SymbolicDist(c))|] => Ok(`SymbolicDist(c))
|
||||||
|
| [|`SamplingDist(`RenderedDist(c))|] => Ok(`RenderedDist(Shape.T.normalize(c)))
|
||||||
|
| e => wrongInputsError(e),
|
||||||
|
),
|
||||||
|
makeRenderedDistFloat("scaleExp", (dist, float) =>
|
||||||
|
verticalScaling(`Exponentiate, dist, float)
|
||||||
|
),
|
||||||
|
makeRenderedDistFloat("scaleMultiply", (dist, float) =>
|
||||||
|
verticalScaling(`Multiply, dist, float)
|
||||||
|
),
|
||||||
|
makeRenderedDistFloat("scaleLog", (dist, float) =>
|
||||||
|
verticalScaling(`Log, dist, float)
|
||||||
|
),
|
||||||
|];
|
|];
|
||||||
|
|
|
@ -193,14 +193,6 @@ module MathAdtToDistDst = {
|
||||||
Error(
|
Error(
|
||||||
"truncate needs three arguments: the expression and both cutoffs",
|
"truncate needs three arguments: the expression and both cutoffs",
|
||||||
)
|
)
|
||||||
| ("scaleMultiply", [|d, `SymbolicDist(`Float(v))|]) =>
|
|
||||||
Ok(`VerticalScaling((`Multiply, d, `SymbolicDist(`Float(v)))))
|
|
||||||
| ("scaleExp", [|d, `SymbolicDist(`Float(v))|]) =>
|
|
||||||
Ok(
|
|
||||||
`VerticalScaling((`Exponentiate, d, `SymbolicDist(`Float(v)))),
|
|
||||||
)
|
|
||||||
| ("scaleLog", [|d, `SymbolicDist(`Float(v))|]) =>
|
|
||||||
Ok(`VerticalScaling((`Log, d, `SymbolicDist(`Float(v)))))
|
|
||||||
| _ => Error("This type not currently supported")
|
| _ => Error("This type not currently supported")
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -245,9 +237,6 @@ module MathAdtToDistDst = {
|
||||||
| "pow"
|
| "pow"
|
||||||
| "leftTruncate"
|
| "leftTruncate"
|
||||||
| "rightTruncate"
|
| "rightTruncate"
|
||||||
| "scaleMultiply"
|
|
||||||
| "scaleExp"
|
|
||||||
| "scaleLog"
|
|
||||||
| "truncate" => operationParser(name, parseArgs())
|
| "truncate" => operationParser(name, parseArgs())
|
||||||
| name =>
|
| name =>
|
||||||
parseArgs()
|
parseArgs()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user