Changed genericDist from being a polymorphic variant

This commit is contained in:
Ozzie Gooen 2022-03-31 14:15:21 -04:00
parent 4b3f24b38d
commit 680726e8b0
5 changed files with 62 additions and 49 deletions

View File

@ -6,10 +6,10 @@ let env: GenericDist_GenericOperation.env = {
xyPointLength: 100, xyPointLength: 100,
} }
let normalDist: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 5.0, stdev: 2.0})) let normalDist: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 5.0, stdev: 2.0}))
let normalDist10: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 10.0, stdev: 2.0})) let normalDist10: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 10.0, stdev: 2.0}))
let normalDist20: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 20.0, stdev: 2.0})) let normalDist20: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 20.0, stdev: 2.0}))
let uniformDist: GenericDist_Types.genericDist = #Symbolic(#Uniform({low: 9.0, high: 10.0})) let uniformDist: GenericDist_Types.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output) let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output)
let {run} = module(GenericDist_GenericOperation) let {run} = module(GenericDist_GenericOperation)
@ -57,7 +57,7 @@ describe("toPointSet", () => {
test("on sample set distribution with under 4 points", () => { test("on sample set distribution with under 4 points", () => {
let result = let result =
run(#fromDist(#toDist(#toPointSet), #SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap( run(#fromDist(#toDist(#toPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
#fromDist(#toFloat(#Mean)), #fromDist(#toFloat(#Mean)),
) )
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed"))) expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))

View File

@ -8,25 +8,25 @@ type pointwiseAddFn = (t, t) => result<t, error>
let sampleN = (t: t, n) => let sampleN = (t: t, n) =>
switch t { switch t {
| #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) | PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
| #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) | Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
| #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented) | SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
} }
let fromFloat = (f: float) => #Symbolic(SymbolicDist.Float.make(f)) let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
let toString = (t: t) => let toString = (t: t) =>
switch t { switch t {
| #PointSet(_) => "Point Set Distribution" | PointSet(_) => "Point Set Distribution"
| #Symbolic(r) => SymbolicDist.T.toString(r) | Symbolic(r) => SymbolicDist.T.toString(r)
| #SampleSet(_) => "Sample Set Distribution" | SampleSet(_) => "Sample Set Distribution"
} }
let normalize = (t: t) => let normalize = (t: t): t =>
switch t { switch t {
| #PointSet(r) => #PointSet(PointSetDist.T.normalize(r)) | PointSet(r) => PointSet(PointSetDist.T.normalize(r))
| #Symbolic(_) => t | Symbolic(_) => t
| #SampleSet(_) => t | SampleSet(_) => t
} }
let toFloatOperation = ( let toFloatOperation = (
@ -34,8 +34,8 @@ let toFloatOperation = (
~toPointSetFn: toPointSetFn, ~toPointSetFn: toPointSetFn,
~distToFloatOperation: Operation.distToFloatOperation, ~distToFloatOperation: Operation.distToFloatOperation,
) => { ) => {
let symbolicSolution = switch t { let symbolicSolution = switch (t: t) {
| #Symbolic(r) => | Symbolic(r) =>
switch SymbolicDist.T.operate(distToFloatOperation, r) { switch SymbolicDist.T.operate(distToFloatOperation, r) {
| Ok(f) => Some(f) | Ok(f) => Some(f)
| _ => None | _ => None
@ -53,10 +53,10 @@ let toFloatOperation = (
// This is tricky because the case of discrete distributions. // This is tricky because the case of discrete distributions.
// Also, change the outputXYPoints/pointSetDistLength details // Also, change the outputXYPoints/pointSetDistLength details
let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSetDist, error> => { let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSetDist, error> => {
switch t { switch (t: t) {
| #PointSet(pointSet) => Ok(pointSet) | PointSet(pointSet) => Ok(pointSet)
| #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r)) | Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
| #SampleSet(r) => { | SampleSet(r) => {
let response = SampleSet.toPointSetDist( let response = SampleSet.toPointSetDist(
~samples=r, ~samples=r,
~samplingInputs={ ~samplingInputs={
@ -76,11 +76,11 @@ let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSe
} }
module Truncate = { module Truncate = {
let trySymbolicSimplification = (leftCutoff, rightCutoff, t): option<t> => let trySymbolicSimplification = (leftCutoff, rightCutoff, t: t): option<t> =>
switch (leftCutoff, rightCutoff, t) { switch (leftCutoff, rightCutoff, t) {
| (None, None, _) => None | (None, None, _) => None
| (lc, rc, #Symbolic(#Uniform(u))) if lc < rc => | (lc, rc, Symbolic(#Uniform(u))) if lc < rc =>
Some(#Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u)))) Some(Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
| _ => None | _ => None
} }
@ -98,9 +98,9 @@ module Truncate = {
switch trySymbolicSimplification(leftCutoff, rightCutoff, t) { switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
| Some(r) => Ok(r) | Some(r) => Ok(r)
| None => | None =>
toPointSetFn(t)->E.R2.fmap(t => toPointSetFn(t)->E.R2.fmap(t => {
#PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t)) GenericDist_Types.PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t))
) })
} }
} }
} }
@ -122,7 +122,7 @@ module AlgebraicCombination = {
t2: t, t2: t,
): option<result<SymbolicDistTypes.symbolicDist, string>> => ): option<result<SymbolicDistTypes.symbolicDist, string>> =>
switch (arithmeticOperation, t1, t2) { switch (arithmeticOperation, t1, t2) {
| (arithmeticOperation, #Symbolic(d1), #Symbolic(d2)) => | (arithmeticOperation, Symbolic(d1), Symbolic(d2)) =>
switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) { switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) {
| #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist)) | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist))
| #Error(er) => Some(Error(er)) | #Error(er) => Some(Error(er))
@ -156,11 +156,11 @@ module AlgebraicCombination = {
//I'm (Ozzie) really just guessing here, very little idea what's best //I'm (Ozzie) really just guessing here, very little idea what's best
let expectedConvolutionCost: t => int = x => let expectedConvolutionCost: t => int = x =>
switch x { switch x {
| #Symbolic(#Float(_)) => 1 | Symbolic(#Float(_)) => 1
| #Symbolic(_) => 1000 | Symbolic(_) => 1000
| #PointSet(Discrete(m)) => m.xyShape->XYShape.T.length | PointSet(Discrete(m)) => m.xyShape->XYShape.T.length
| #PointSet(Mixed(_)) => 1000 | PointSet(Mixed(_)) => 1000
| #PointSet(Continuous(_)) => 1000 | PointSet(Continuous(_)) => 1000
| _ => 1000 | _ => 1000
} }
@ -177,14 +177,24 @@ module AlgebraicCombination = {
~t2: t, ~t2: t,
): result<t, error> => { ): result<t, error> => {
switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) { switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
| Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist)) | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist))
| Some(Error(e)) => Error(Other(e)) | Some(Error(e)) => Error(Other(e))
| None => | None =>
switch chooseConvolutionOrMonteCarlo(t1, t2) { switch chooseConvolutionOrMonteCarlo(t1, t2) {
| #CalculateWithMonteCarlo => | #CalculateWithMonteCarlo =>
runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(r => #SampleSet(r)) runMonteCarlo(
toSampleSetFn,
arithmeticOperation,
t1,
t2,
)->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
| #CalculateWithConvolution => | #CalculateWithConvolution =>
runConvolution(toPointSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(r => #PointSet(r)) runConvolution(
toPointSetFn,
arithmeticOperation,
t1,
t2,
)->E.R2.fmap(r => GenericDist_Types.PointSet(r))
} }
} }
} }
@ -207,7 +217,7 @@ let pointwiseCombination = (
t2, t2,
) )
) )
->E.R2.fmap(r => #PointSet(r)) ->E.R2.fmap(r => GenericDist_Types.PointSet(r))
} }
let pointwiseCombinationFloat = ( let pointwiseCombinationFloat = (
@ -232,7 +242,7 @@ let pointwiseCombinationFloat = (
) )
}) })
} }
m->E.R2.fmap(r => #PointSet(r)) m->E.R2.fmap(r => GenericDist_Types.PointSet(r))
} }
//Note: The result should always cumulatively sum to 1. This would be good to test. //Note: The result should always cumulatively sum to 1. This would be good to test.

View File

@ -71,14 +71,14 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
let toPointSetFn = r => { let toPointSetFn = r => {
switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) { switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) {
| Dist(#PointSet(p)) => Ok(p) | Dist(PointSet(p)) => Ok(p)
| e => Error(OutputLocal.toErrorOrUnreachable(e)) | e => Error(OutputLocal.toErrorOrUnreachable(e))
} }
} }
let toSampleSetFn = r => { let toSampleSetFn = r => {
switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) { switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
| Dist(#SampleSet(p)) => Ok(p) | Dist(SampleSet(p)) => Ok(p)
| e => Error(OutputLocal.toErrorOrUnreachable(e)) | e => Error(OutputLocal.toErrorOrUnreachable(e))
} }
} }
@ -114,10 +114,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
| #toDist(#toPointSet) => | #toDist(#toPointSet) =>
dist dist
->GenericDist.toPointSet(~xyPointLength, ~sampleCount) ->GenericDist.toPointSet(~xyPointLength, ~sampleCount)
->E.R2.fmap(r => Dist(#PointSet(r))) ->E.R2.fmap(r => Dist(PointSet(r)))
->OutputLocal.fromResult ->OutputLocal.fromResult
| #toDist(#toSampleSet(n)) => | #toDist(#toSampleSet(n)) =>
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(#SampleSet(r)))->OutputLocal.fromResult dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
| #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented) | #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
| #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) => | #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) =>
dist dist

View File

@ -15,7 +15,11 @@ let runFromDist: (
~functionCallInfo: GenericDist_Types.Operation.fromDist, ~functionCallInfo: GenericDist_Types.Operation.fromDist,
GenericDist_Types.genericDist, GenericDist_Types.genericDist,
) => outputType ) => outputType
let runFromFloat: (~env: env, ~functionCallInfo: GenericDist_Types.Operation.fromDist, float) => outputType let runFromFloat: (
~env: env,
~functionCallInfo: GenericDist_Types.Operation.fromDist,
float,
) => outputType
module Output: { module Output: {
type t = outputType type t = outputType

View File

@ -1,8 +1,7 @@
type genericDist = [ type genericDist =
| #PointSet(PointSetTypes.pointSetDist) | PointSet(PointSetTypes.pointSetDist)
| #SampleSet(array<float>) | SampleSet(array<float>)
| #Symbolic(SymbolicDistTypes.symbolicDist) | Symbolic(SymbolicDistTypes.symbolicDist)
]
type error = type error =
| NotYetImplemented | NotYetImplemented