Made more dists using new format

This commit is contained in:
Ozzie Gooen 2020-08-16 18:55:38 +01:00
parent c57cc3144e
commit 479fdbb491
5 changed files with 72 additions and 65 deletions

View File

@ -8,8 +8,6 @@ let rec toString: node => string =
Operation.Algebraic.format(op, toString(t1), toString(t2)) Operation.Algebraic.format(op, toString(t1), toString(t2))
| `PointwiseCombination(op, t1, t2) => | `PointwiseCombination(op, t1, t2) =>
Operation.Pointwise.format(op, toString(t1), toString(t2)) Operation.Pointwise.format(op, toString(t1), toString(t2))
| `VerticalScaling(scaleOp, t, scaleBy) =>
Operation.Scale.format(scaleOp, toString(t), toString(scaleBy))
| `Normalize(t) => "normalize(k" ++ toString(t) ++ ")" | `Normalize(t) => "normalize(k" ++ toString(t) ++ ")"
| `Truncate(lc, rc, t) => | `Truncate(lc, rc, t) =>
Operation.T.truncateToString(lc, rc, toString(t)) Operation.T.truncateToString(lc, rc, toString(t))

View File

@ -91,36 +91,6 @@ module AlgebraicCombination = {
); );
}; };
module VerticalScaling = {
let operationToLeaf =
(evaluationParams: evaluationParams, scaleOp, t, scaleBy) => {
// scaleBy has to be a single float, otherwise we'll return an error.
let fn = (secondary, main) =>
Operation.Scale.toFn(scaleOp, main, secondary);
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
let renderedShape = Render.render(evaluationParams, t);
let s =
switch (renderedShape, scaleBy) {
| (Ok(`RenderedDist(rs)), `SymbolicDist(`Float(scaleBy))) =>
Ok(
`RenderedDist(
Shape.T.mapY(
~integralSumCacheFn=integralSumCacheFn(scaleBy),
~integralCacheFn=integralCacheFn(scaleBy),
~fn=fn(scaleBy),
rs,
),
),
)
| (Error(e1), _) => Error(e1)
| (_, _) => Error("Can only scale by float values.")
};
s;
};
};
module PointwiseCombination = { module PointwiseCombination = {
let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => { let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => {
switch ( switch (
@ -309,8 +279,6 @@ let rec toLeaf =
t1, t1,
t2, t2,
) )
| `VerticalScaling(scaleOp, t, scaleBy) =>
VerticalScaling.operationToLeaf(evaluationParams, scaleOp, t, scaleBy)
| `Truncate(leftCutoff, rightCutoff, t) => | `Truncate(leftCutoff, rightCutoff, t) =>
Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t) Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t)
| `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t) | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t)
@ -336,12 +304,7 @@ let rec toLeaf =
let components = let components =
r r
|> E.A.fmap(((dist, weight)) => |> E.A.fmap(((dist, weight)) =>
`VerticalScaling(( `FunctionCall("scaleExp", [|dist, `SymbolicDist(`Float(weight))|]));
`Multiply,
dist,
`SymbolicDist(`Float(weight)),
))
);
let pointwiseSum = let pointwiseSum =
components components
|> Js.Array.sliceFrom(1) |> Js.Array.sliceFrom(1)

View File

@ -25,7 +25,6 @@ module ExpressionTree = {
| `Function(array(string), node) | `Function(array(string), node)
| `AlgebraicCombination(algebraicOperation, node, node) | `AlgebraicCombination(algebraicOperation, node, node)
| `PointwiseCombination(pointwiseOperation, node, node) | `PointwiseCombination(pointwiseOperation, node, node)
| `VerticalScaling(scaleOperation, node, node)
| `Normalize(node) | `Normalize(node)
| `Render(node) | `Render(node)
| `Truncate(option(float), option(float), node) | `Truncate(option(float), option(float), node)

View File

@ -1,6 +1,9 @@
open TypeSystem; open TypeSystem;
let wrongInputsError = (r) => {Js.log2("Wrong inputs", r); Error("Wrong inputs")}; let wrongInputsError = r => {
Js.log2("Wrong inputs", r);
Error("Wrong inputs");
};
let to_: (float, float) => result(node, string) = let to_: (float, float) => result(node, string) =
(low, high) => (low, high) =>
@ -20,7 +23,7 @@ let makeSymbolicFromTwoFloats = (name, fn) =>
~run= ~run=
fun fun
| [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b))) | [|`Float(a), `Float(b)|] => Ok(`SymbolicDist(fn(a, b)))
| e => wrongInputsError(e) | e => wrongInputsError(e),
); );
let makeSymbolicFromOneFloat = (name, fn) => let makeSymbolicFromOneFloat = (name, fn) =>
@ -31,21 +34,32 @@ let makeSymbolicFromOneFloat = (name, fn) =>
~run= ~run=
fun fun
| [|`Float(a)|] => Ok(`SymbolicDist(fn(a))) | [|`Float(a)|] => Ok(`SymbolicDist(fn(a)))
| e => wrongInputsError(e) | e => wrongInputsError(e),
); );
let makeDistFloat = (name, fn) => let makeDistFloat = (name, fn) =>
Function.make( Function.make(
~name, ~name,
~output=`SamplingDistribution, ~output=`SamplingDistribution,
~inputs=[|`SamplingDistribution, `Float|], ~inputs=[|`SamplingDistribution, `Float|],
~run= ~run=
fun fun
| [|`SamplingDist(a), `Float(b)|] => (fn(a,b)) | [|`SamplingDist(a), `Float(b)|] => fn(a, b)
| e => wrongInputsError(e) | e => wrongInputsError(e),
); );
let makeDist = (name, fn) => let makeRenderedDistFloat = (name, fn) =>
Function.make(
~name,
~output=`RenderedDistribution,
~inputs=[|`RenderedDistribution, `Float|],
~run=
fun
| [|`RenderedDist(a), `Float(b)|] => fn(a, b)
| e => wrongInputsError(e),
);
let makeDist = (name, fn) =>
Function.make( Function.make(
~name, ~name,
~output=`SamplingDistribution, ~output=`SamplingDistribution,
@ -53,7 +67,7 @@ let makeDist = (name, fn) =>
~run= ~run=
fun fun
| [|`SamplingDist(a)|] => fn(a) | [|`SamplingDist(a)|] => fn(a)
| e => wrongInputsError(e) | e => wrongInputsError(e),
); );
let floatFromDist = let floatFromDist =
@ -71,6 +85,22 @@ let floatFromDist =
}; };
}; };
let verticalScaling = (scaleOp, rs, scaleBy) => {
// scaleBy has to be a single float, otherwise we'll return an error.
let fn = (secondary, main) =>
Operation.Scale.toFn(scaleOp, main, secondary);
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(scaleOp);
let integralCacheFn = Operation.Scale.toIntegralCacheFn(scaleOp);
Ok(`RenderedDist(
Shape.T.mapY(
~integralSumCacheFn=integralSumCacheFn(scaleBy),
~integralCacheFn=integralCacheFn(scaleBy),
~fn=fn(scaleBy),
rs,
),
));
};
let functions = [| let functions = [|
makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make), makeSymbolicFromTwoFloats("normal", SymbolicDist.Normal.make),
makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make), makeSymbolicFromTwoFloats("uniform", SymbolicDist.Uniform.make),
@ -87,8 +117,8 @@ let functions = [|
~inputs=[|`Float, `Float|], ~inputs=[|`Float, `Float|],
~run= ~run=
fun fun
| [|`Float(a), `Float(b)|] => to_(a,b) | [|`Float(a), `Float(b)|] => to_(a, b)
| e => wrongInputsError(e) | e => wrongInputsError(e),
), ),
Function.make( Function.make(
~name="triangular", ~name="triangular",
@ -99,11 +129,39 @@ let functions = [|
| [|`Float(a), `Float(b), `Float(c)|] => | [|`Float(a), `Float(b), `Float(c)|] =>
SymbolicDist.Triangular.make(a, b, c) SymbolicDist.Triangular.make(a, b, c)
|> E.R.fmap(r => `SymbolicDist(r)) |> E.R.fmap(r => `SymbolicDist(r))
| e => wrongInputsError(e) | e => wrongInputsError(e),
), ),
makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)), makeDistFloat("pdf", (dist, float) => floatFromDist(`Pdf(float), dist)),
makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)), makeDistFloat("inv", (dist, float) => floatFromDist(`Inv(float), dist)),
makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)), makeDistFloat("cdf", (dist, float) => floatFromDist(`Cdf(float), dist)),
makeDist("mean", (dist) => floatFromDist(`Mean, dist)), makeDist("mean", dist => floatFromDist(`Mean, dist)),
makeDist("sample", (dist) => floatFromDist(`Sample, dist)) makeDist("sample", dist => floatFromDist(`Sample, dist)),
Function.make(
~name="render",
~output=`RenderedDistribution,
~inputs=[|`RenderedDistribution|],
~run=
fun
| [|`RenderedDist(c)|] => Ok(`RenderedDist(c))
| e => wrongInputsError(e),
),
Function.make(
~name="normalize",
~output=`SamplingDistribution,
~inputs=[|`SamplingDistribution|],
~run=
fun
| [|`SamplingDist(`SymbolicDist(c))|] => Ok(`SymbolicDist(c))
| [|`SamplingDist(`RenderedDist(c))|] => Ok(`RenderedDist(Shape.T.normalize(c)))
| e => wrongInputsError(e),
),
makeRenderedDistFloat("scaleExp", (dist, float) =>
verticalScaling(`Exponentiate, dist, float)
),
makeRenderedDistFloat("scaleMultiply", (dist, float) =>
verticalScaling(`Multiply, dist, float)
),
makeRenderedDistFloat("scaleLog", (dist, float) =>
verticalScaling(`Log, dist, float)
),
|]; |];

View File

@ -193,14 +193,6 @@ module MathAdtToDistDst = {
Error( Error(
"truncate needs three arguments: the expression and both cutoffs", "truncate needs three arguments: the expression and both cutoffs",
) )
| ("scaleMultiply", [|d, `SymbolicDist(`Float(v))|]) =>
Ok(`VerticalScaling((`Multiply, d, `SymbolicDist(`Float(v)))))
| ("scaleExp", [|d, `SymbolicDist(`Float(v))|]) =>
Ok(
`VerticalScaling((`Exponentiate, d, `SymbolicDist(`Float(v)))),
)
| ("scaleLog", [|d, `SymbolicDist(`Float(v))|]) =>
Ok(`VerticalScaling((`Log, d, `SymbolicDist(`Float(v)))))
| _ => Error("This type not currently supported") | _ => Error("This type not currently supported")
} }
}); });
@ -245,9 +237,6 @@ module MathAdtToDistDst = {
| "pow" | "pow"
| "leftTruncate" | "leftTruncate"
| "rightTruncate" | "rightTruncate"
| "scaleMultiply"
| "scaleExp"
| "scaleLog"
| "truncate" => operationParser(name, parseArgs()) | "truncate" => operationParser(name, parseArgs())
| name => | name =>
parseArgs() parseArgs()