Remove NaN from pointwise operations

This commit is contained in:
Sam Nolan 2022-04-23 14:09:06 -04:00
parent 98bf4f81c7
commit c7e601e15b
17 changed files with 168 additions and 148 deletions

View File

@ -50,7 +50,11 @@ module Internals = {
let dist1 = dist1'->DistributionTypes.Symbolic
let dist2 = dist2'->DistributionTypes.Symbolic
let received =
distOp(dist1, dist2)->E.R2.fmap(mean)->E.R2.fmap(run)->E.R2.fmap(toFloat)->E.R.toExn("Expected float", _)
distOp(dist1, dist2)
->E.R2.fmap(mean)
->E.R2.fmap(run)
->E.R2.fmap(toFloat)
->E.R.toExn("Expected float", _)
let expected = floatOp(runMean(dist1), runMean(dist2))
switch received {
| None => expectImpossiblePath(description)

View File

@ -97,12 +97,6 @@ describe("eval on distribution functions", () => {
testEval("log10(uniform(5,8))", "Ok(Sample Set Distribution)")
})
describe("dotLog", () => {
testEval("dotLog(normal(5,2), 3)", "Ok(Point Set Distribution)")
testEval("dotLog(normal(5,2), 3)", "Ok(Point Set Distribution)")
testEval("dotLog(normal(5,2), normal(10,1))", "Ok(Point Set Distribution)")
})
describe("dotAdd", () => {
testEval("dotAdd(normal(5,2), lognormal(10,2))", "Ok(Point Set Distribution)")
testEval("dotAdd(normal(5,2), 3)", "Ok(Point Set Distribution)")

View File

@ -2,6 +2,7 @@
module.exports = {
preset: "ts-jest",
testEnvironment: "node",
bail: true,
setupFilesAfterEnv: [
"<rootdir>/../../node_modules/bisect_ppx/src/runtime/js/jest.bs.js",
],

View File

@ -160,14 +160,14 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2)
->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult
| ToDistCombination(Pointwise, arithmeticOperation, #Dist(t2)) =>
| ToDistCombination(Pointwise, algebraicCombination, #Dist(t2)) =>
dist
->GenericDist.pointwiseCombination(~toPointSetFn, ~arithmeticOperation, ~t2)
->GenericDist.pointwiseCombination(~toPointSetFn, ~algebraicCombination, ~t2)
->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult
| ToDistCombination(Pointwise, arithmeticOperation, #Float(f)) =>
| ToDistCombination(Pointwise, algebraicCombination, #Float(f)) =>
dist
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~arithmeticOperation, ~f)
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~algebraicCombination, ~f)
->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult
}

View File

@ -46,30 +46,14 @@ module Error = {
}
@genType
module Operation = {
module DistributionOperation = {
@genType
type pointsetXSelection = [#Linear | #ByWeight]
type direction =
| Algebraic
| Pointwise
type arithmeticOperation = [
| #Add
| #Multiply
| #Subtract
| #Divide
| #Power
| #Logarithm
]
let arithmeticToFn = (arithmetic: arithmeticOperation) =>
switch arithmetic {
| #Add => \"+."
| #Multiply => \"*."
| #Subtract => \"-."
| #Power => \"**"
| #Divide => \"/."
| #Logarithm => (a, b) => log(a) /. log(b)
}
type toFloat = [
| #Cdf(float)
| #Inv(float)
@ -78,11 +62,6 @@ module Operation = {
| #Sample
]
@genType
type pointsetXSelection = [#Linear | #ByWeight]
}
module DistributionOperation = {
type toDist =
| Normalize
| ToPointSet
@ -99,13 +78,9 @@ module DistributionOperation = {
| ToSparkline(int)
type fromDist =
| ToFloat(Operation.toFloat)
| ToFloat(toFloat)
| ToDist(toDist)
| ToDistCombination(
Operation.direction,
Operation.arithmeticOperation,
[#Dist(genericDist) | #Float(float)],
)
| ToDistCombination(direction, Operation.Algebraic.t, [#Dist(genericDist) | #Float(float)])
| ToString(toString)
| ToBool(toBool)

View File

@ -68,7 +68,7 @@ let toPointSet = (
t,
~xyPointLength,
~sampleCount,
~xSelection: DistributionTypes.Operation.pointsetXSelection=#ByWeight,
~xSelection: DistributionTypes.DistributionOperation.pointsetXSelection=#ByWeight,
(),
): result<PointSetTypes.pointSetDist, error> => {
switch (t: t) {
@ -148,7 +148,7 @@ let truncate = Truncate.run
*/
module AlgebraicCombination = {
let tryAnalyticalSimplification = (
arithmeticOperation: DistributionTypes.Operation.arithmeticOperation,
arithmeticOperation: Operation.algebraicOperation,
t1: t,
t2: t,
): option<result<SymbolicDistTypes.symbolicDist, Operation.Error.invalidOperationError>> =>
@ -174,7 +174,7 @@ module AlgebraicCombination = {
let runMonteCarlo = (
toSampleSet: toSampleSetFn,
arithmeticOperation: DistributionTypes.Operation.arithmeticOperation,
arithmeticOperation: Operation.algebraicOperation,
t1: t,
t2: t,
): result<t, error> => {
@ -241,27 +241,23 @@ let algebraicCombination = AlgebraicCombination.run
let pointwiseCombination = (
t1: t,
~toPointSetFn: toPointSetFn,
~arithmeticOperation,
~algebraicCombination: Operation.algebraicOperation,
~t2: t,
): result<t, error> => {
E.R.merge(toPointSetFn(t1), toPointSetFn(t2))
->E.R2.fmap(((t1, t2)) =>
PointSetDist.combinePointwise(
DistributionTypes.Operation.arithmeticToFn(arithmeticOperation),
t1,
t2,
)
E.R.merge(toPointSetFn(t1), toPointSetFn(t2))->E.R.bind(((t1, t2)) =>
PointSetDist.combinePointwise(Operation.Algebraic.toFn(algebraicCombination), t1, t2)
->E.R2.fmap(r => DistributionTypes.PointSet(r))
->E.R2.errMap(err => DistributionTypes.OperationError(err))
)
->E.R2.fmap(r => DistributionTypes.PointSet(r))
}
let pointwiseCombinationFloat = (
t: t,
~toPointSetFn: toPointSetFn,
~arithmeticOperation: DistributionTypes.Operation.arithmeticOperation,
~algebraicCombination: Operation.algebraicOperation,
~f: float,
): result<t, error> => {
let m = switch arithmeticOperation {
let m = switch algebraicCombination {
| #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid)
| (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
toPointSetFn(t)->E.R.bind(t => {

View File

@ -28,7 +28,7 @@ let toPointSet: (
t,
~xyPointLength: int,
~sampleCount: int,
~xSelection: DistributionTypes.Operation.pointsetXSelection=?,
~xSelection: DistributionTypes.DistributionOperation.pointsetXSelection=?,
unit,
) => result<PointSetTypes.pointSetDist, error>
let toSparkline: (t, ~sampleCount: int, ~bucketCount: int=?, unit) => result<string, error>
@ -45,21 +45,21 @@ let algebraicCombination: (
t,
~toPointSetFn: toPointSetFn,
~toSampleSetFn: toSampleSetFn,
~arithmeticOperation: DistributionTypes.Operation.arithmeticOperation,
~arithmeticOperation: Operation.algebraicOperation,
~t2: t,
) => result<t, error>
let pointwiseCombination: (
t,
~toPointSetFn: toPointSetFn,
~arithmeticOperation: DistributionTypes.Operation.arithmeticOperation,
~algebraicCombination: Operation.algebraicOperation,
~t2: t,
) => result<t, error>
let pointwiseCombinationFloat: (
t,
~toPointSetFn: toPointSetFn,
~arithmeticOperation: DistributionTypes.Operation.arithmeticOperation,
~algebraicCombination: Operation.algebraicOperation,
~f: float,
) => result<t, error>

View File

@ -243,10 +243,13 @@ let combineShapesContinuousDiscrete = (
outXYShapes
|> E.A.fmap(XYShape.T.fromZippedArray)
|> E.A.fold_left(
XYShape.PointwiseCombination.combine(
\"+.",
XYShape.XtoY.continuousInterpolator(#Linear, #UseZero),
),
(acc, x) =>
XYShape.PointwiseCombination.combine(
(a, b) => Ok(a +. b),
XYShape.XtoY.continuousInterpolator(#Linear, #UseZero),
acc,
x,
)->E.R.toExn("Error, unexpected failure", _),
XYShape.T.empty,
)
}

View File

@ -88,10 +88,10 @@ let stepwiseToLinear = (t: t): t =>
let combinePointwise = (
~integralSumCachesFn=(_, _) => None,
~distributionType: PointSetTypes.distributionType=#PDF,
fn: (float, float) => float,
fn: (float, float) => result<float, Operation.Error.invalidOperationError>,
t1: PointSetTypes.continuousShape,
t2: PointSetTypes.continuousShape,
): PointSetTypes.continuousShape => {
): result<PointSetTypes.continuousShape, 'e> => {
// If we're adding the distributions, and we know the total of each, then we
// can just sum them up. Otherwise, all bets are off.
let combinedIntegralSum = Common.combineIntegralSums(
@ -119,9 +119,8 @@ let combinePointwise = (
let interpolator = XYShape.XtoY.continuousInterpolator(t1.interpolation, extrapolation)
make(
~integralSumCache=combinedIntegralSum,
XYShape.PointwiseCombination.combine(fn, interpolator, t1.xyShape, t2.xyShape),
XYShape.PointwiseCombination.combine(fn, interpolator, t1.xyShape, t2.xyShape)->E.R2.fmap(x =>
make(~integralSumCache=combinedIntegralSum, x)
)
}
@ -140,11 +139,25 @@ let updateIntegralSumCache = (integralSumCache, t: t): t => {
let updateIntegralCache = (integralCache, t: t): t => {...t, integralCache: integralCache}
let sum = (
~integralSumCachesFn: (float, float) => option<float>=(_, _) => None,
continuousShapes,
): t =>
continuousShapes |> E.A.fold_left(
(x, y) =>
combinePointwise(~integralSumCachesFn, (a, b) => Ok(a +. b), x, y)->E.R.toExn(
"Addition should never fail",
_,
),
empty,
)
let reduce = (
~integralSumCachesFn: (float, float) => option<float>=(_, _) => None,
fn,
fn: (float, float) => result<float, 'e>,
continuousShapes,
) => continuousShapes |> E.A.fold_left(combinePointwise(~integralSumCachesFn, fn), empty)
): result<t, 'e> =>
continuousShapes |> E.A.R.foldM(combinePointwise(~integralSumCachesFn, fn), empty)
let mapYResult = (
~integralSumCacheFn=_ => None,

View File

@ -49,11 +49,11 @@ let combinePointwise = (
make(
~integralSumCache=combinedIntegralSum,
XYShape.PointwiseCombination.combine(
\"+.",
(a, b) => Ok(a +. b),
XYShape.XtoY.discreteInterpolator,
t1.xyShape,
t2.xyShape,
),
)->E.R.toExn("Addition operation should never fail", _),
)
}

View File

@ -146,8 +146,7 @@ module T = Dist({
let discreteIntegral = Continuous.stepwiseToLinear(Discrete.T.Integral.get(t.discrete))
Continuous.make(
XYShape.PointwiseCombination.combine(
\"+.",
XYShape.PointwiseCombination.addCombine(
XYShape.XtoY.continuousInterpolator(#Linear, #UseOutermostPoints),
Continuous.getShape(continuousIntegral),
Continuous.getShape(discreteIntegral),
@ -280,7 +279,7 @@ let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t
let ccConvResult = Continuous.combineAlgebraically(op, t1.continuous, t2.continuous)
let dcConvResult = Continuous.combineAlgebraicallyWithDiscrete(op, t2.continuous, t1.discrete)
let cdConvResult = Continuous.combineAlgebraicallyWithDiscrete(op, t1.continuous, t2.discrete)
let continuousConvResult = Continuous.reduce(\"+.", [ccConvResult, dcConvResult, cdConvResult])
let continuousConvResult = Continuous.sum([ccConvResult, dcConvResult, cdConvResult])
// ... finally, discrete (*) discrete => discrete, obviously:
let discreteConvResult = Discrete.combineAlgebraically(op, t1.discrete, t2.discrete)
@ -302,10 +301,10 @@ let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t
let combinePointwise = (
~integralSumCachesFn=(_, _) => None,
~integralCachesFn=(_, _) => None,
fn,
fn: (float, float) => result<float, 'e>,
t1: t,
t2: t,
): t => {
): result<t, 'e> => {
let reducedDiscrete =
[t1, t2] |> E.A.fmap(toDiscrete) |> E.A.O.concatSomes |> Discrete.reduce(~integralSumCachesFn)
@ -326,11 +325,12 @@ let combinePointwise = (
t1.integralCache,
t2.integralCache,
)
make(
~integralSumCache=combinedIntegralSum,
~integralCache=combinedIntegral,
~discrete=reducedDiscrete,
~continuous=reducedContinuous,
reducedContinuous->E.R2.fmap(continuous =>
make(
~integralSumCache=combinedIntegralSum,
~integralCache=combinedIntegral,
~discrete=reducedDiscrete,
~continuous,
)
)
}

View File

@ -60,19 +60,28 @@ let combinePointwise = (
PointSetTypes.continuousShape,
PointSetTypes.continuousShape,
) => option<PointSetTypes.continuousShape>=(_, _) => None,
fn,
fn: (float, float) => result<float, Operation.Error.invalidOperationError>,
t1: t,
t2: t,
) =>
): result<PointSetTypes.pointSetDist, Operation.Error.invalidOperationError> =>
switch (t1, t2) {
| (Continuous(m1), Continuous(m2)) =>
PointSetTypes.Continuous(Continuous.combinePointwise(~integralSumCachesFn, fn, m1, m2))
Continuous.combinePointwise(
~integralSumCachesFn,
fn,
m1,
m2,
)->E.R2.fmap(x => PointSetTypes.Continuous(x))
| (Discrete(m1), Discrete(m2)) =>
PointSetTypes.Discrete(Discrete.combinePointwise(~integralSumCachesFn, m1, m2))
Ok(PointSetTypes.Discrete(Discrete.combinePointwise(~integralSumCachesFn, m1, m2)))
| (m1, m2) =>
PointSetTypes.Mixed(
Mixed.combinePointwise(~integralSumCachesFn, ~integralCachesFn, fn, toMixed(m1), toMixed(m2)),
)
Mixed.combinePointwise(
~integralSumCachesFn,
~integralCachesFn,
fn,
toMixed(m1),
toMixed(m2),
)->E.R2.fmap(x => PointSetTypes.Mixed(x))
}
module T = Dist({

View File

@ -24,7 +24,6 @@ module Helpers = {
| "dotPow" => #Power
| "multiply" => #Multiply
| "dotMultiply" => #Multiply
| "dotLog" => #Logarithm
| _ => #Multiply
}
@ -41,7 +40,7 @@ module Helpers = {
}
let toFloatFn = (
fnCall: DistributionTypes.Operation.toFloat,
fnCall: DistributionTypes.DistributionOperation.toFloat,
dist: DistributionTypes.genericDist,
) => {
FromDist(DistributionTypes.DistributionOperation.ToFloat(fnCall), dist)
@ -243,15 +242,12 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option<
| "dotMultiply"
| "dotSubtract"
| "dotDivide"
| "dotPow"
| "dotLog") as arithmetic,
| "dotPow") as arithmetic,
[_, _] as args,
) =>
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd)
)
| ("dotLog", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn(Pointwise, "dotLog", a, GenericDist.fromFloat(Math.e))->Some
| ("dotExp", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn(Pointwise, "dotPow", GenericDist.fromFloat(Math.e), a)->Some
| _ => None

View File

@ -192,6 +192,7 @@ module R = {
| Ok(f) => fmap(f, a)
| Error(err) => Error(err)
}
// (a1 -> a2 -> r) -> m a1 -> m a2 -> m r // not in Rationale
let liftM2: (('a, 'b) => 'c, result<'a, 'd>, result<'b, 'd>) => result<'c, 'd> = (op, xR, yR) => {
ap'(fmap(op, xR), yR)
@ -444,6 +445,31 @@ module A = {
bringErrorUp |> Belt.Result.map(_, forceOpen)
}
let filterOk = (x: array<result<'a, 'b>>): array<'a> => fmap(R.toOption, x)->O.concatSomes
let forM = (x: array<'a>, fn: 'a => result<'b, 'c>): result<array<'b>, 'c> =>
firstErrorOrOpen(fmap(fn, x))
let foldM = (fn: ('c, 'a) => result<'b, 'e>, init: 'c, x: array<'a>): result<'c, 'e> => {
let acc = ref(init)
let final = ref(Ok())
let break = ref(false)
let i = ref(0)
while break.contents != true && i.contents < length(x) {
switch fn(acc.contents, x[i.contents]) {
| Ok(r) => acc := r
| Error(err) => {
final := Error(err)
break := true
}
}
i := i.contents + 1
}
switch final.contents {
| Ok(_) => Ok(acc.contents)
| Error(err) => Error(err)
}
}
}
module Sorted = {

View File

@ -51,6 +51,31 @@ module Error = {
}
}
let power = (a: float, b: float): result<float, Error.invalidOperationError> =>
if a >= 0.0 {
Ok(a ** b)
} else {
Error(ComplexNumberError)
}
let divide = (a: float, b: float): result<float, Error.invalidOperationError> =>
if b != 0.0 {
Ok(a /. b)
} else {
Error(DivisionByZeroError)
}
let logarithm = (a: float, b: float): result<float, Error.invalidOperationError> =>
if b == 1. {
Error(DivisionByZeroError)
} else if b == 0. {
Ok(0.)
} else if a > 0.0 && b > 0.0 {
Ok(log(a) /. log(b))
} else {
Error(ComplexNumberError)
}
module Algebraic = {
type t = algebraicOperation
let toFn: (t, float, float) => result<float, Error.invalidOperationError> = (x, a, b) =>
@ -58,26 +83,9 @@ module Algebraic = {
| #Add => Ok(a +. b)
| #Subtract => Ok(a -. b)
| #Multiply => Ok(a *. b)
| #Power =>
if a >= 0.0 {
Ok(a ** b)
} else {
Error(ComplexNumberError)
}
| #Divide =>
if b != 0.0 {
Ok(a /. b)
} else {
Error(DivisionByZeroError)
}
| #Logarithm =>
if b == 1. {
Error(DivisionByZeroError)
} else if a > 0.0 && b > 0.0 {
Ok(log(a) /. log(b))
} else {
Error(ComplexNumberError)
}
| #Power => power(a, b)
| #Divide => divide(a, b)
| #Logarithm => logarithm(a, b)
}
let toString = x =>
@ -124,24 +132,9 @@ module Scale = {
let toFn = (x: t, a: float, b: float): result<float, Error.invalidOperationError> =>
switch x {
| #Multiply => Ok(a *. b)
| #Divide =>
if b != 0.0 {
Ok(a /. b)
} else {
Error(DivisionByZeroError)
}
| #Power =>
if a > 0.0 {
Ok(a ** b)
} else {
Error(ComplexNumberError)
}
| #Logarithm =>
if a > 0.0 && b > 0.0 {
Ok(log(a) /. log(b))
} else {
Error(DivisionByZeroError)
}
| #Divide => divide(a, b)
| #Power => power(a, b)
| #Logarithm => logarithm(a, b)
}
let format = (operation: t, value, scaleBy) =>

View File

@ -233,7 +233,12 @@ module Zipped = {
module PointwiseCombination = {
// t1Interpolator and t2Interpolator are functions from XYShape.XtoY, e.g. linearBetweenPointsExtrapolateFlat.
let combine: ((float, float) => float, interpolator, T.t, T.t) => T.t = %raw(`
let combine: (
(float, float) => result<float, Operation.Error.invalidOperationError>,
interpolator,
T.t,
T.t,
) => result<T.t, Operation.Error.invalidOperationError> = %raw(`
// This function combines two xyShapes by looping through both of them simultaneously.
// It always moves on to the next smallest x, whether that's in the first or second input's xs,
// and interpolates the value on the other side, thus accumulating xs and ys.
@ -281,13 +286,28 @@ module PointwiseCombination = {
}
outX.push(x);
outY.push(fn(ya, yb));
// Here I check whether the operation was a success. If it was
// keep going. Otherwise, stop and throw the error back to user
let newY = fn(ya, yb);
if(newY.TAG === 0){
outY.push(newY._0);
}
else {
return newY;
}
}
return {xs: outX, ys: outY};
return {TAG: 0, _0: {xs: outX, ys: outY}, [Symbol.for("name")]: "Ok"};
}
`)
let addCombine = (interpolator: interpolator, t1: T.t, t2: T.t): T.t =>
combine((a, b) => Ok(a +. b), interpolator, t1, t2)->E.R.toExn(
"Add operation should never fail",
_,
)
let combineEvenXs = (~fn, ~xToYSelection, sampleCount, t1: T.t, t2: T.t) =>
switch (E.A.length(t1.xs), E.A.length(t2.xs)) {
| (0, 0) => T.empty

View File

@ -255,16 +255,6 @@ dist2 = triangular(1,2,3)
dist1 .^ dist2`}
/>
### Pointwise logarithm
TODO: write about the semantics and the case handling re scalar vs. dist and log base.
<SquiggleEditor
initialSquiggleString={`dist1 = 1 to 10
dist2 = triangular(1,2,3)
dotLog(dist1, dist2)`}
/>
## Standard functions on distributions
### Probability density function