Added Truncate to sci.res

This commit is contained in:
Ozzie Gooen 2022-03-26 22:06:19 -04:00
parent c5afb2d867
commit 2ec1bfd068

View File

@ -1,3 +1,5 @@
//TODO: multimodal, add interface, split up a little bit, test somehow, track performance, refactor sampleSet, refactor ASTEvaluator.res.
type error = type error =
| NeedsPointSetConversion | NeedsPointSetConversion
| InputsNeedPointSetConversion | InputsNeedPointSetConversion
@ -18,7 +20,7 @@ module OperationType = {
| #Pointwise | #Pointwise
] ]
type combination = [ type arithmeticOperation = [
| #Add | #Add
| #Multiply | #Multiply
| #Subtract | #Subtract
@ -27,8 +29,8 @@ module OperationType = {
| #Log | #Log
] ]
let combinationToFn = (combination: combination) => let arithmeticToFn = (arithmetic: arithmeticOperation) =>
switch combination { switch arithmetic {
| #Add => \"+." | #Add => \"+."
| #Multiply => \"*." | #Multiply => \"*."
| #Subtract => \"-." | #Subtract => \"-."
@ -49,6 +51,7 @@ module OperationType = {
| #normalize | #normalize
| #toPointSet | #toPointSet
| #toSampleSet(int) | #toSampleSet(int)
| #truncate(option<float>, option<float>)
] ]
type toFloatArray = [ type toFloatArray = [
@ -58,7 +61,7 @@ module OperationType = {
type t = [ type t = [
| #toFloat(toFloat) | #toFloat(toFloat)
| #toDist(toDist) | #toDist(toDist)
| #toDistCombination(direction, combination, [#Dist(genericDist) | #Float(float)]) | #toDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)])
] ]
} }
@ -76,6 +79,14 @@ module T = {
} }
} }
let normalize = (t: t) => {
switch t {
| #PointSet(r) => #PointSet(PointSetDist.T.normalize(r))
| #Symbolic(_) => t
| #SampleSet(_) => t
}
}
let toFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => { let toFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => {
switch t { switch t {
| #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) => | #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) =>
@ -117,10 +128,46 @@ module T = {
} }
} }
module Truncate = {
let trySymbolicSimplification = (leftCutoff, rightCutoff, t): option<t> =>
switch (leftCutoff, rightCutoff, t) {
| (None, None, _) => None
| (lc, rc, #Symbolic(#Uniform(u))) if lc < rc =>
Some(#Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
| _ => None
}
let run = (
toPointSet: toPointSetFn,
leftCutoff: option<float>,
rightCutoff: option<float>,
t: t,
): result<t, error> => {
let doesNotNeedCutoff = E.O.isNone(leftCutoff) && E.O.isNone(rightCutoff)
if doesNotNeedCutoff {
Ok(t)
} else {
switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
| Some(r) => Ok(r)
| None =>
toPointSet(t) |> E.R.fmap(t =>
#PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t))
)
}
}
}
}
/* Given two random variables A and B, this returns the distribution
of a new variable that is the result of the operation on A and B.
For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2).
In general, this is implemented via convolution. */
module AlgebraicCombination = { module AlgebraicCombination = {
let tryAnalyticalSimplification = (operation: OperationType.combination, t1: t, t2: t): option< let tryAnalyticalSimplification = (
result<SymbolicDistTypes.symbolicDist, string>, operation: OperationType.arithmeticOperation,
> => t1: t,
t2: t,
): option<result<SymbolicDistTypes.symbolicDist, string>> =>
switch (operation, t1, t2) { switch (operation, t1, t2) {
| (operation, #Symbolic(d1), #Symbolic(d2)) => | (operation, #Symbolic(d1), #Symbolic(d2)) =>
switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) { switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation) {
@ -133,7 +180,7 @@ module T = {
let runConvolution = ( let runConvolution = (
toPointSet: toPointSetFn, toPointSet: toPointSetFn,
operation: OperationType.combination, operation: OperationType.arithmeticOperation,
t1: t, t1: t,
t2: t, t2: t,
) => ) =>
@ -143,7 +190,7 @@ module T = {
let runMonteCarlo = ( let runMonteCarlo = (
toSampleSet: toSampleSetFn, toSampleSet: toSampleSetFn,
operation: OperationType.combination, operation: OperationType.arithmeticOperation,
t1: t, t1: t,
t2: t, t2: t,
) => { ) => {
@ -163,7 +210,7 @@ module T = {
| _ => 1000 | _ => 1000
} }
let chooseConvolutionOrMonteCarlo = (t1: t, t2: t) => let chooseConvolutionOrMonteCarlo = (t2: t, t1: t) =>
expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000 expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000
? #CalculateWithMonteCarlo ? #CalculateWithMonteCarlo
: #CalculateWithConvolution : #CalculateWithConvolution
@ -189,22 +236,23 @@ module T = {
} }
} }
let pointwiseCombination = (toPointSet: toPointSetFn, operation, t1: t, t2: t): result< //TODO: Add faster pointwiseCombine fn
let pointwiseCombination = (toPointSet: toPointSetFn, operation, t2: t, t1: t): result<
t, t,
error, error,
> => { > => {
E.R.merge(toPointSet(t1), toPointSet(t2)) E.R.merge(toPointSet(t1), toPointSet(t2))
|> E.R.fmap(((t1, t2)) => |> E.R.fmap(((t1, t2)) =>
PointSetDist.combinePointwise(OperationType.combinationToFn(operation), t1, t2) PointSetDist.combinePointwise(OperationType.arithmeticToFn(operation), t1, t2)
) )
|> E.R.fmap(r => #PointSet(r)) |> E.R.fmap(r => #PointSet(r))
} }
let pointwiseCombinationFloat = ( let pointwiseCombinationFloat = (
toPointSet: toPointSetFn, toPointSet: toPointSetFn,
operation: OperationType.combination, operation: OperationType.arithmeticOperation,
t: t,
f: float, f: float,
t: t,
): result<t, error> => { ): result<t, error> => {
switch operation { switch operation {
| #Add | #Subtract => Error(DistributionVerticalShiftIsInvalid) | #Add | #Subtract => Error(DistributionVerticalShiftIsInvalid)
@ -259,36 +307,41 @@ module OmniRunner = {
switch reCall(~value=r, ~fnName=#toDist(#toPointSet), ()) { switch reCall(~value=r, ~fnName=#toDist(#toPointSet), ()) {
| #Dist(#PointSet(p)) => Ok(p) | #Dist(#PointSet(p)) => Ok(p)
| #Error(r) => Error(r) | #Error(r) => Error(r)
| _ => Error(Other("Impossible error")) | _ => Error(ImpossiblePath)
} }
} }
let toSampleSet = r => { let toSampleSet = r => {
switch reCall(~value=r, ~fnName=#toDist(#toSampleSet(sampleCount)), ()) { switch reCall(~value=r, ~fnName=#toDist(#toSampleSet(sampleCount)), ()) {
| #Dist(#SampleSet(p)) => Ok(p) | #Dist(#SampleSet(p)) => Ok(p)
| #Error(r) => Error(r) | #Error(r) => Error(r)
| _ => Error(Other("Impossible error")) | _ => Error(ImpossiblePath)
} }
} }
switch (fnName, value) { switch fnName {
// | (#toFloat(n), v) => toFloat(toPointSet, v, n) // | (#toFloat(n), v) => toFloat(toPointSet, v, n)
| (#toFloat(fnName), _) => | #toFloat(fnName) =>
T.toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult T.toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult
| (#toDist(#normalize), #PointSet(r)) => #Dist(#PointSet(PointSetDist.T.normalize(r))) | #toDist(#normalize) => value |> T.normalize |> (r => #Dist(r))
| (#toDist(#normalize), #Symbolic(_)) => #Dist(value) | #toDist(#truncate(left, right)) =>
| (#toDist(#normalize), #SampleSet(_)) => #Dist(value) value |> T.Truncate.run(toPointSet, left, right) |> E.R.fmap(r => #Dist(r)) |> fromResult
| (#toDist(#toPointSet), _) => | #toDist(#toPointSet) =>
value |> T.toPointSet(xyPointLength) |> E.R.fmap(r => #Dist(#PointSet(r))) |> fromResult value |> T.toPointSet(xyPointLength) |> E.R.fmap(r => #Dist(#PointSet(r))) |> fromResult
| (#toDist(#toSampleSet(n)), _) => | #toDist(#toSampleSet(n)) =>
value |> T.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult value |> T.sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult
| (#toDistCombination(#Algebraic, _, #Float(_)), _) => #Error(NotYetImplemented) | #toDistCombination(#Algebraic, _, #Float(_)) => #Error(NotYetImplemented)
| (#toDistCombination(#Algebraic, operation, #Dist(p2)), p1) => | #toDistCombination(#Algebraic, operation, #Dist(value2)) =>
T.AlgebraicCombination.run(toPointSet, toSampleSet, operation, p1, p2) value
|> T.AlgebraicCombination.run(toPointSet, toSampleSet, operation, value2)
|> E.R.fmap(r => #Dist(r)) |> E.R.fmap(r => #Dist(r))
|> fromResult |> fromResult
| (#toDistCombination(#Pointwise, operation, #Dist(p2)), p1) => | #toDistCombination(#Pointwise, operation, #Dist(value2)) =>
T.pointwiseCombination(toPointSet, operation, p1, p2) |> E.R.fmap(r => #Dist(r)) |> fromResult value
| (#toDistCombination(#Pointwise, operation, #Float(f)), _) => |> T.pointwiseCombination(toPointSet, operation, value2)
T.pointwiseCombinationFloat(toPointSet, operation, value, f) |> E.R.fmap(r => #Dist(r))
|> fromResult
| #toDistCombination(#Pointwise, operation, #Float(f)) =>
value
|> T.pointwiseCombinationFloat(toPointSet, operation, f)
|> E.R.fmap(r => #Dist(r)) |> E.R.fmap(r => #Dist(r))
|> fromResult |> fromResult
} }