Added named paramaters to most GenericDist functions

This commit is contained in:
Ozzie Gooen 2022-03-31 09:19:27 -04:00
parent d61f521a0e
commit f2d03c8f11
4 changed files with 94 additions and 63 deletions

View File

@ -29,10 +29,14 @@ let normalize = (t: t) =>
| #SampleSet(_) => t | #SampleSet(_) => t
} }
let operationToFloat = (t, toPointSet: toPointSetFn, fnName) => { let operationToFloat = (
t,
~toPointSetFn: toPointSetFn,
~operation: Operation.distToFloatOperation,
) => {
let symbolicSolution = switch t { let symbolicSolution = switch t {
| #Symbolic(r) => | #Symbolic(r) =>
switch SymbolicDist.T.operate(fnName, r) { switch SymbolicDist.T.operate(operation, r) {
| Ok(f) => Some(f) | Ok(f) => Some(f)
| _ => None | _ => None
} }
@ -41,7 +45,7 @@ let operationToFloat = (t, toPointSet: toPointSetFn, fnName) => {
switch symbolicSolution { switch symbolicSolution {
| Some(r) => Ok(r) | Some(r) => Ok(r)
| None => toPointSet(t)->E.R2.fmap(PointSetDist.operate(fnName)) | None => toPointSetFn(t)->E.R2.fmap(PointSetDist.operate(operation))
} }
} }
@ -84,9 +88,10 @@ module Truncate = {
let run = ( let run = (
t: t, t: t,
toPointSet: toPointSetFn, ~toPointSetFn: toPointSetFn,
leftCutoff: option<float>, ~leftCutoff=None: option<float>,
rightCutoff: option<float>, ~rightCutoff=None: option<float>,
(),
): result<t, error> => { ): result<t, error> => {
let doesNotNeedCutoff = E.O.isNone(leftCutoff) && E.O.isNone(rightCutoff) let doesNotNeedCutoff = E.O.isNone(leftCutoff) && E.O.isNone(rightCutoff)
if doesNotNeedCutoff { if doesNotNeedCutoff {
@ -95,7 +100,7 @@ module Truncate = {
switch trySymbolicSimplification(leftCutoff, rightCutoff, t) { switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
| Some(r) => Ok(r) | Some(r) => Ok(r)
| None => | None =>
toPointSet(t)->E.R2.fmap(t => toPointSetFn(t)->E.R2.fmap(t =>
#PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t)) #PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t))
) )
} }
@ -168,20 +173,20 @@ module AlgebraicCombination = {
let run = ( let run = (
t1: t, t1: t,
toPointSet: toPointSetFn, ~toPointSetFn: toPointSetFn,
toSampleSet: toSampleSetFn, ~toSampleSetFn: toSampleSetFn,
algebraicOp, ~operation,
t2: t, ~t2: t,
): result<t, error> => { ): result<t, error> => {
switch tryAnalyticalSimplification(algebraicOp, t1, t2) { switch tryAnalyticalSimplification(operation, t1, t2) {
| Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist)) | Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist))
| Some(Error(e)) => Error(Other(e)) | Some(Error(e)) => Error(Other(e))
| None => | None =>
switch chooseConvolutionOrMonteCarlo(t1, t2) { switch chooseConvolutionOrMonteCarlo(t1, t2) {
| #CalculateWithMonteCarlo => | #CalculateWithMonteCarlo =>
runMonteCarlo(toSampleSet, algebraicOp, t1, t2)->E.R2.fmap(r => #SampleSet(r)) runMonteCarlo(toSampleSetFn, operation, t1, t2)->E.R2.fmap(r => #SampleSet(r))
| #CalculateWithConvolution => | #CalculateWithConvolution =>
runConvolution(toPointSet, algebraicOp, t1, t2)->E.R2.fmap(r => #PointSet(r)) runConvolution(toPointSetFn, operation, t1, t2)->E.R2.fmap(r => #PointSet(r))
} }
} }
} }
@ -190,11 +195,11 @@ module AlgebraicCombination = {
let algebraicCombination = AlgebraicCombination.run let algebraicCombination = AlgebraicCombination.run
//TODO: Add faster pointwiseCombine fn //TODO: Add faster pointwiseCombine fn
let pointwiseCombination = (t1: t, toPointSet: toPointSetFn, operation, t2: t): result< let pointwiseCombination = (t1: t, ~toPointSetFn: toPointSetFn, ~operation, ~t2: t): result<
t, t,
error, error,
> => { > => {
E.R.merge(toPointSet(t1), toPointSet(t2)) E.R.merge(toPointSetFn(t1), toPointSetFn(t2))
->E.R2.fmap(((t1, t2)) => ->E.R2.fmap(((t1, t2)) =>
PointSetDist.combinePointwise(GenericDist_Types.Operation.arithmeticToFn(operation), t1, t2) PointSetDist.combinePointwise(GenericDist_Types.Operation.arithmeticToFn(operation), t1, t2)
) )
@ -203,22 +208,22 @@ let pointwiseCombination = (t1: t, toPointSet: toPointSetFn, operation, t2: t):
let pointwiseCombinationFloat = ( let pointwiseCombinationFloat = (
t: t, t: t,
toPointSet: toPointSetFn, ~toPointSetFn: toPointSetFn,
operation: GenericDist_Types.Operation.arithmeticOperation, ~operation: GenericDist_Types.Operation.arithmeticOperation,
f: float, ~float: float,
): result<t, error> => { ): result<t, error> => {
let m = switch operation { let m = switch operation {
| #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid) | #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid)
| (#Multiply | #Divide | #Exponentiate | #Log) as operation => | (#Multiply | #Divide | #Exponentiate | #Log) as operation =>
toPointSet(t)->E.R2.fmap(t => { toPointSetFn(t)->E.R2.fmap(t => {
//TODO: Move to PointSet codebase //TODO: Move to PointSet codebase
let fn = (secondary, main) => Operation.Scale.toFn(operation, main, secondary) let fn = (secondary, main) => Operation.Scale.toFn(operation, main, secondary)
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(operation) let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(operation)
let integralCacheFn = Operation.Scale.toIntegralCacheFn(operation) let integralCacheFn = Operation.Scale.toIntegralCacheFn(operation)
PointSetDist.T.mapY( PointSetDist.T.mapY(
~integralSumCacheFn=integralSumCacheFn(f), ~integralSumCacheFn=integralSumCacheFn(float),
~integralCacheFn=integralCacheFn(f), ~integralCacheFn=integralCacheFn(float),
~fn=fn(f), ~fn=fn(float),
t, t,
) )
}) })
@ -230,8 +235,8 @@ let pointwiseCombinationFloat = (
//Note: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this. //Note: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this.
let mixture = ( let mixture = (
values: array<(t, float)>, values: array<(t, float)>,
scaleMultiply: scaleMultiplyFn, ~scaleMultiplyFn: scaleMultiplyFn,
pointwiseAdd: pointwiseAddFn, ~pointwiseAddFn: pointwiseAddFn,
) => { ) => {
if E.A.length(values) == 0 { if E.A.length(values) == 0 {
Error(GenericDist_Types.Other("mixture must have at least 1 element")) Error(GenericDist_Types.Other("mixture must have at least 1 element"))
@ -239,13 +244,13 @@ let mixture = (
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
let properlyWeightedValues = let properlyWeightedValues =
values values
->E.A2.fmap(((dist, weight)) => scaleMultiply(dist, weight /. totalWeight)) ->E.A2.fmap(((dist, weight)) => scaleMultiplyFn(dist, weight /. totalWeight))
->E.A.R.firstErrorOrOpen ->E.A.R.firstErrorOrOpen
properlyWeightedValues->E.R.bind(values => { properlyWeightedValues->E.R.bind(values => {
values values
|> Js.Array.sliceFrom(1) |> Js.Array.sliceFrom(1)
|> E.A.fold_left( |> E.A.fold_left(
(acc, x) => E.R.bind(acc, acc => pointwiseAdd(acc, x)), (acc, x) => E.R.bind(acc, acc => pointwiseAddFn(acc, x)),
Ok(E.A.unsafe_get(values, 0)), Ok(E.A.unsafe_get(values, 0)),
) )
}) })

View File

@ -13,32 +13,46 @@ let toString: t => string
let normalize: t => t let normalize: t => t
let operationToFloat: (t, toPointSetFn, Operation.distToFloatOperation) => result<float, error> let operationToFloat: (
t,
~toPointSetFn: toPointSetFn,
~operation: Operation.distToFloatOperation,
) => result<float, error>
let toPointSet: (t, int) => result<PointSetTypes.pointSetDist, error> let toPointSet: (t, int) => result<PointSetTypes.pointSetDist, error>
let truncate: (t, toPointSetFn, option<float>, option<float>) => result<t, error> let truncate: (
t,
~toPointSetFn: toPointSetFn,
~leftCutoff: option<float>=?,
~rightCutoff: option<float>=?,
unit,
) => result<t, error>
let algebraicCombination: ( let algebraicCombination: (
t, t,
toPointSetFn, ~toPointSetFn: toPointSetFn,
toSampleSetFn, ~toSampleSetFn: toSampleSetFn,
GenericDist_Types.Operation.arithmeticOperation, ~operation: GenericDist_Types.Operation.arithmeticOperation,
t, ~t2: t,
) => result<t, error> ) => result<t, error>
let pointwiseCombination: ( let pointwiseCombination: (
t, t,
toPointSetFn, ~toPointSetFn: toPointSetFn,
GenericDist_Types.Operation.arithmeticOperation, ~operation: GenericDist_Types.Operation.arithmeticOperation,
t, ~t2: t,
) => result<t, error> ) => result<t, error>
let pointwiseCombinationFloat: ( let pointwiseCombinationFloat: (
t, t,
toPointSetFn, ~toPointSetFn: toPointSetFn,
GenericDist_Types.Operation.arithmeticOperation, ~operation: GenericDist_Types.Operation.arithmeticOperation,
float, ~float: float,
) => result<t, error> ) => result<t, error>
let mixture: (array<(t, float)>, scaleMultiplyFn, pointwiseAddFn) => result<t, error> let mixture: (
array<(t, float)>,
~scaleMultiplyFn: scaleMultiplyFn,
~pointwiseAddFn: pointwiseAddFn,
) => result<t, error>

View File

@ -48,11 +48,17 @@ let fromResult = (r: result<outputType, error>): outputType =>
| Error(e) => #GenDistError(e) | Error(e) => #GenDistError(e)
} }
//This is used to catch errors in other switch statements.
let _errorMap = (o: outputType): error =>
switch o {
| #GenDistError(r) => r
| _ => Unreachable
}
let outputToDistResult = (o: outputType): result<genericDist, error> => let outputToDistResult = (o: outputType): result<genericDist, error> =>
switch o { switch o {
| #Dist(r) => Ok(r) | #Dist(r) => Ok(r)
| #GenDistError(r) => Error(r) | r => Error(_errorMap(r))
| _ => Error(Unreachable)
} }
let rec run = (extra, fnName: operation): outputType => { let rec run = (extra, fnName: operation): outputType => {
@ -62,19 +68,17 @@ let rec run = (extra, fnName: operation): outputType => {
run(extra, fnName) run(extra, fnName)
} }
let toPointSet = r => { let toPointSetFn = r => {
switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) { switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) {
| #Dist(#PointSet(p)) => Ok(p) | #Dist(#PointSet(p)) => Ok(p)
| #GenDistError(r) => Error(r) | r => Error(_errorMap(r))
| _ => Error(Unreachable)
} }
} }
let toSampleSet = r => { let toSampleSetFn = r => {
switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) { switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
| #Dist(#SampleSet(p)) => Ok(p) | #Dist(#SampleSet(p)) => Ok(p)
| #GenDistError(r) => Error(r) | r => Error(_errorMap(r))
| _ => Error(Unreachable)
} }
} }
@ -93,42 +97,50 @@ let rec run = (extra, fnName: operation): outputType => {
let fromDistFn = (subFnName: GenericDist_Types.Operation.fromDist, dist: genericDist) => let fromDistFn = (subFnName: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
switch subFnName { switch subFnName {
| #toFloat(fnName) => | #toFloat(fnName) =>
GenericDist.operationToFloat(dist, toPointSet, fnName)->E.R2.fmap(r => #Float(r))->fromResult GenericDist.operationToFloat(dist, ~toPointSetFn, ~operation=fnName)
->E.R2.fmap(r => #Float(r))
->fromResult
| #toString => dist->GenericDist.toString->(r => #String(r)) | #toString => dist->GenericDist.toString->(r => #String(r))
| #toDist(#consoleLog) => { | #toDist(#inspect) => {
Js.log2("Console log requested: ", dist) Js.log2("Console log requested: ", dist)
#Dist(dist) #Dist(dist)
} }
| #toDist(#normalize) => dist->GenericDist.normalize->(r => #Dist(r)) | #toDist(#normalize) => dist->GenericDist.normalize->(r => #Dist(r))
| #toDist(#truncate(left, right)) => | #toDist(#truncate(leftCutoff, rightCutoff)) =>
dist->GenericDist.truncate(toPointSet, left, right)->E.R2.fmap(r => #Dist(r))->fromResult GenericDist.truncate(~toPointSetFn, ~leftCutoff, ~rightCutoff, dist, ())
->E.R2.fmap(r => #Dist(r))
->fromResult
| #toDist(#toPointSet) => | #toDist(#toPointSet) =>
dist->GenericDist.toPointSet(xyPointLength)->E.R2.fmap(r => #Dist(#PointSet(r)))->fromResult dist->GenericDist.toPointSet(xyPointLength)->E.R2.fmap(r => #Dist(#PointSet(r)))->fromResult
| #toDist(#toSampleSet(n)) => | #toDist(#toSampleSet(n)) =>
dist->GenericDist.sampleN(n)->E.R2.fmap(r => #Dist(#SampleSet(r)))->fromResult dist->GenericDist.sampleN(n)->E.R2.fmap(r => #Dist(#SampleSet(r)))->fromResult
| #toDistCombination(#Algebraic, _, #Float(_)) => #GenDistError(NotYetImplemented) | #toDistCombination(#Algebraic, _, #Float(_)) => #GenDistError(NotYetImplemented)
| #toDistCombination(#Algebraic, operation, #Dist(dist2)) => | #toDistCombination(#Algebraic, operation, #Dist(t2)) =>
dist dist
->GenericDist.algebraicCombination(toPointSet, toSampleSet, operation, dist2) ->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~operation, ~t2)
->E.R2.fmap(r => #Dist(r)) ->E.R2.fmap(r => #Dist(r))
->fromResult ->fromResult
| #toDistCombination(#Pointwise, operation, #Dist(dist2)) => | #toDistCombination(#Pointwise, operation, #Dist(t2)) =>
dist dist
->GenericDist.pointwiseCombination(toPointSet, operation, dist2) ->GenericDist.pointwiseCombination(~toPointSetFn, ~operation, ~t2)
->E.R2.fmap(r => #Dist(r)) ->E.R2.fmap(r => #Dist(r))
->fromResult ->fromResult
| #toDistCombination(#Pointwise, operation, #Float(f)) => | #toDistCombination(#Pointwise, operation, #Float(float)) =>
dist dist
->GenericDist.pointwiseCombinationFloat(toPointSet, operation, f) ->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~operation, ~float)
->E.R2.fmap(r => #Dist(r)) ->E.R2.fmap(r => #Dist(r))
->fromResult ->fromResult
} }
switch fnName { switch fnName {
| #fromDist(subFnName, dist) => fromDistFn(subFnName, dist) | #fromDist(subFnName, dist) => fromDistFn(subFnName, dist)
| #fromFloat(subFnName, float) => reCall(~fnName=#fromDist(subFnName, GenericDist.fromFloat(float)), ()) | #fromFloat(subFnName, float) =>
reCall(~fnName=#fromDist(subFnName, GenericDist.fromFloat(float)), ())
| #mixture(dists) => | #mixture(dists) =>
dists->GenericDist.mixture(scaleMultiply, pointwiseAdd)->E.R2.fmap(r => #Dist(r))->fromResult dists
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd)
->E.R2.fmap(r => #Dist(r))
->fromResult
} }
} }
@ -148,4 +160,4 @@ let outputMap = (
| (#fromFloat(_), _) => Error(Other("Expected float, got something else")) | (#fromFloat(_), _) => Error(Other("Expected float, got something else"))
} }
newFnCall->E.R2.fmap(r => run(extra, r))->fromResult newFnCall->E.R2.fmap(r => run(extra, r))->fromResult
} }

View File

@ -48,7 +48,7 @@ module Operation = {
| #toPointSet | #toPointSet
| #toSampleSet(int) | #toSampleSet(int)
| #truncate(option<float>, option<float>) | #truncate(option<float>, option<float>)
| #consoleLog | #inspect
] ]
type toFloatArray = [ type toFloatArray = [
@ -85,7 +85,7 @@ module Operation = {
| #toDist(#toPointSet) => `toPointSet` | #toDist(#toPointSet) => `toPointSet`
| #toDist(#toSampleSet(r)) => `toSampleSet(${E.I.toString(r)})` | #toDist(#toSampleSet(r)) => `toSampleSet(${E.I.toString(r)})`
| #toDist(#truncate(_, _)) => `truncate` | #toDist(#truncate(_, _)) => `truncate`
| #toDist(#consoleLog) => `consoleLog` | #toDist(#inspect) => `inspect`
| #toString => `toString` | #toString => `toString`
| #toDistCombination(#Algebraic, _, _) => `algebraic` | #toDistCombination(#Algebraic, _, _) => `algebraic`
| #toDistCombination(#Pointwise, _, _) => `pointwise` | #toDistCombination(#Pointwise, _, _) => `pointwise`