Added simple exponentiation

This commit is contained in:
Ozzie Gooen 2020-07-31 14:02:29 +01:00
parent bcbbcf0f43
commit e266e6557c
6 changed files with 54 additions and 35 deletions

View File

@ -45,11 +45,8 @@ let toDiscretePointMassesFromTriangulars =
if (inverse) {
for (i in 1 to n - 2) {
Belt.Array.set(
masses,
i - 1,
(xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2.,
) |> ignore;
Belt.Array.set(masses, i - 1, (xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2.)
|> ignore;
// this only works when the whole triange is either on the left or on the right of zero
let a = xs[i - 1];
@ -80,14 +77,12 @@ let toDiscretePointMassesFromTriangulars =
} else {
for (i in 1 to n - 2) {
// area of triangle = width * height / 2
Belt.Array.set(
masses,
i - 1,
(xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2.,
) |> ignore;
Belt.Array.set(masses, i - 1, (xs[i + 1] -. xs[i - 1]) *. ys[i] /. 2.)
|> ignore;
// means of triangle = (a + b + c) / 3
Belt.Array.set(means, i - 1, (xs[i - 1] +. xs[i] +. xs[i + 1]) /. 3.) |> ignore;
Belt.Array.set(means, i - 1, (xs[i - 1] +. xs[i] +. xs[i + 1]) /. 3.)
|> ignore;
// variance of triangle = (a^2 + b^2 + c^2 - ab - ac - bc) / 18
Belt.Array.set(
@ -102,7 +97,8 @@ let toDiscretePointMassesFromTriangulars =
-. xsProdN2[i - 1]
)
/. 18.,
) |> ignore;
)
|> ignore;
();
};
{n: n - 2, masses, means, variances};
@ -134,8 +130,10 @@ let combineShapesContinuousContinuous =
| `Subtract => ((m1, m2) => m1 -. m2)
| `Multiply => ((m1, m2) => m1 *. m2)
| `Divide => ((m1, mInv2) => m1 *. mInv2)
| `Exponentiate => ((m1, mInv2) => m1 ** mInv2)
}; // note: here, mInv2 = mean(1 / t2) ~= 1 / mean(t2)
// TODO: I don't know what the variances are for exponentatiation
// converts the variances and means of the two inputs into the variance of the output
let combineVariancesFn =
switch (op) {
@ -144,6 +142,8 @@ let combineShapesContinuousContinuous =
| `Multiply => (
(v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2.
)
| `Exponentiate =>
((v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2.);
| `Divide => (
(v1, vInv2, m1, mInv2) =>
v1 *. vInv2 +. v1 *. mInv2 ** 2. +. vInv2 *. m1 ** 2.
@ -225,9 +225,12 @@ let toDiscretePointMassesFromDiscrete =
};
let combineShapesContinuousDiscrete =
(op: ExpressionTypes.algebraicOperation, continuousShape: DistTypes.xyShape, discreteShape: DistTypes.xyShape)
(
op: ExpressionTypes.algebraicOperation,
continuousShape: DistTypes.xyShape,
discreteShape: DistTypes.xyShape,
)
: DistTypes.xyShape => {
let t1n = continuousShape |> XYShape.T.length;
let t2n = discreteShape |> XYShape.T.length;
@ -248,15 +251,19 @@ let combineShapesContinuousDiscrete =
Belt.Array.set(
dxyShape,
i,
(fn(continuousShape.xs[i], discreteShape.xs[j]),
continuousShape.ys[i] *. discreteShape.ys[j]),
) |> ignore;
(
fn(continuousShape.xs[i], discreteShape.xs[j]),
continuousShape.ys[i] *. discreteShape.ys[j],
),
)
|> ignore;
();
};
Belt.Array.set(outXYShapes, j, dxyShape) |> ignore;
();
}
| `Multiply
| `Exponentiate
| `Divide =>
for (j in 0 to t2n - 1) {
// creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes.
@ -266,8 +273,12 @@ let combineShapesContinuousDiscrete =
Belt.Array.set(
dxyShape,
i,
(fn(continuousShape.xs[i], discreteShape.xs[j]), continuousShape.ys[i] *. discreteShape.ys[j] /. discreteShape.xs[j]),
) |> ignore;
(
fn(continuousShape.xs[i], discreteShape.xs[j]),
{continuousShape.ys[i] *. discreteShape.ys[j] /. discreteShape.xs[j]}
),
)
|> ignore;
();
};
Belt.Array.set(outXYShapes, j, dxyShape) |> ignore;
@ -278,7 +289,10 @@ let combineShapesContinuousDiscrete =
outXYShapes
|> E.A.fmap(XYShape.T.fromZippedArray)
|> E.A.fold_left(
XYShape.PointwiseCombination.combine((+.),
XYShape.XtoY.continuousInterpolator(`Linear, `UseZero)),
XYShape.T.empty);
XYShape.PointwiseCombination.combine(
(+.),
XYShape.XtoY.continuousInterpolator(`Linear, `UseZero),
),
XYShape.T.empty,
);
};

View File

@ -151,7 +151,7 @@ module PointwiseCombination = {
};
};
let pointwiseMultiply = (evaluationParams: evaluationParams, t1: t, t2: t) => {
let pointwiseCombine = (fn, evaluationParams: evaluationParams, t1: t, t2: t) => {
// TODO: construct a function that we can easily sample from, to construct
// a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look.
// TODO: This should work for symbolic distributions too!
@ -176,7 +176,8 @@ module PointwiseCombination = {
) => {
switch (pointwiseOp) {
| `Add => pointwiseAdd(evaluationParams, t1, t2)
| `Multiply => pointwiseMultiply(evaluationParams, t1, t2)
| `Multiply => pointwiseCombine(( *. ),evaluationParams, t1, t2)
| `Exponentiate => pointwiseCombine(( *. ),evaluationParams, t1, t2)
};
};
};
@ -279,7 +280,7 @@ module Render = {
| `SymbolicDist(d) =>
Ok(
`RenderedDist(
SymbolicDist.T.toShape(1234, d),
SymbolicDist.T.toShape(evaluationParams.samplingInputs.shapeLength, d),
),
)
| `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here

View File

@ -1,5 +1,5 @@
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide];
type pointwiseOperation = [ | `Add | `Multiply];
type algebraicOperation = [ | `Add | `Multiply | `Subtract | `Divide | `Exponentiate];
type pointwiseOperation = [ | `Add | `Multiply | `Exponentiate];
type scaleOperation = [ | `Multiply | `Exponentiate | `Log];
type distToFloatOperation = [
| `Pdf(float)

View File

@ -177,11 +177,12 @@ module MathAdtToDistDst = {
};
};
// Error("Dotwise exponentiation needs two operands")
let operationParser =
(
name: string,
args: result(array(ExpressionTypes.ExpressionTree.node), string),
) => {
):result(ExpressionTypes.ExpressionTree.node,string) => {
let toOkAlgebraic = r => Ok(`AlgebraicCombination(r));
let toOkPointwise = r => Ok(`PointwiseCombination(r));
let toOkTruncate = r => Ok(`Truncate(r));
@ -195,6 +196,8 @@ module MathAdtToDistDst = {
| ("subtract", _) => Error("Subtraction needs two operands")
| ("multiply", [|l, r|]) => toOkAlgebraic((`Multiply, l, r))
| ("multiply", _) => Error("Multiplication needs two operands")
| ("pow", [|l,r|]) => toOkAlgebraic((`Exponentiate, l, r))
| ("pow", _) => Error("Exponentiation needs two operands")
| ("dotMultiply", [|l, r|]) => toOkPointwise((`Multiply, l, r))
| ("dotMultiply", _) =>
Error("Dotwise multiplication needs two operands")
@ -203,7 +206,6 @@ module MathAdtToDistDst = {
Error("Dotwise addition needs two operands")
| ("divide", [|l, r|]) => toOkAlgebraic((`Divide, l, r))
| ("divide", _) => Error("Division needs two operands")
| ("pow", _) => Error("Exponentiation is not yet supported.")
| ("leftTruncate", [|d, `SymbolicDist(`Float(lc))|]) =>
toOkTruncate((Some(lc), None, d))
| ("leftTruncate", _) =>

View File

@ -7,6 +7,7 @@ module Algebraic = {
| `Add => (+.)
| `Subtract => (-.)
| `Multiply => ( *. )
| `Exponentiate => ( ** )
| `Divide => (/.);
let applyFn = (t, f1, f2) => {
@ -21,6 +22,7 @@ module Algebraic = {
| `Add => "+"
| `Subtract => "-"
| `Multiply => "*"
| `Exponentiate => ( "**" )
| `Divide => "/";
let format = (a, b, c) => b ++ " " ++ toString(a) ++ " " ++ c;

View File

@ -125,7 +125,7 @@ module Internals = {
let inputsToShape = (inputs: inputs) => {
MathJsParser.fromString(inputs.guesstimatorString)
|> E.R.bind(_, g => runProgram(inputs, g))
|> E.R.bind(_, r => E.A.last(r) |> E.O.toResult("No rendered lines"));
|> E.R.bind(_, r => E.A.last(r) |> E.O.toResult("No rendered lines") |> E.R.fmap(Shape.T.normalize));
};
let outputToDistPlus = (inputs: Inputs.inputs, shape: DistTypes.shape) => {