Changed genericDist from being a polymorphic variant
This commit is contained in:
parent
4b3f24b38d
commit
680726e8b0
|
@ -6,10 +6,10 @@ let env: GenericDist_GenericOperation.env = {
|
||||||
xyPointLength: 100,
|
xyPointLength: 100,
|
||||||
}
|
}
|
||||||
|
|
||||||
let normalDist: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 5.0, stdev: 2.0}))
|
let normalDist: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 5.0, stdev: 2.0}))
|
||||||
let normalDist10: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 10.0, stdev: 2.0}))
|
let normalDist10: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 10.0, stdev: 2.0}))
|
||||||
let normalDist20: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 20.0, stdev: 2.0}))
|
let normalDist20: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 20.0, stdev: 2.0}))
|
||||||
let uniformDist: GenericDist_Types.genericDist = #Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
let uniformDist: GenericDist_Types.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
||||||
|
|
||||||
let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output)
|
let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output)
|
||||||
let {run} = module(GenericDist_GenericOperation)
|
let {run} = module(GenericDist_GenericOperation)
|
||||||
|
@ -57,7 +57,7 @@ describe("toPointSet", () => {
|
||||||
|
|
||||||
test("on sample set distribution with under 4 points", () => {
|
test("on sample set distribution with under 4 points", () => {
|
||||||
let result =
|
let result =
|
||||||
run(#fromDist(#toDist(#toPointSet), #SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
|
run(#fromDist(#toDist(#toPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
|
||||||
#fromDist(#toFloat(#Mean)),
|
#fromDist(#toFloat(#Mean)),
|
||||||
)
|
)
|
||||||
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))
|
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))
|
||||||
|
|
|
@ -8,25 +8,25 @@ type pointwiseAddFn = (t, t) => result<t, error>
|
||||||
|
|
||||||
let sampleN = (t: t, n) =>
|
let sampleN = (t: t, n) =>
|
||||||
switch t {
|
switch t {
|
||||||
| #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
| PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
||||||
| #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
|
| Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
|
||||||
| #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
|
| SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
let fromFloat = (f: float) => #Symbolic(SymbolicDist.Float.make(f))
|
let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
|
||||||
|
|
||||||
let toString = (t: t) =>
|
let toString = (t: t) =>
|
||||||
switch t {
|
switch t {
|
||||||
| #PointSet(_) => "Point Set Distribution"
|
| PointSet(_) => "Point Set Distribution"
|
||||||
| #Symbolic(r) => SymbolicDist.T.toString(r)
|
| Symbolic(r) => SymbolicDist.T.toString(r)
|
||||||
| #SampleSet(_) => "Sample Set Distribution"
|
| SampleSet(_) => "Sample Set Distribution"
|
||||||
}
|
}
|
||||||
|
|
||||||
let normalize = (t: t) =>
|
let normalize = (t: t): t =>
|
||||||
switch t {
|
switch t {
|
||||||
| #PointSet(r) => #PointSet(PointSetDist.T.normalize(r))
|
| PointSet(r) => PointSet(PointSetDist.T.normalize(r))
|
||||||
| #Symbolic(_) => t
|
| Symbolic(_) => t
|
||||||
| #SampleSet(_) => t
|
| SampleSet(_) => t
|
||||||
}
|
}
|
||||||
|
|
||||||
let toFloatOperation = (
|
let toFloatOperation = (
|
||||||
|
@ -34,8 +34,8 @@ let toFloatOperation = (
|
||||||
~toPointSetFn: toPointSetFn,
|
~toPointSetFn: toPointSetFn,
|
||||||
~distToFloatOperation: Operation.distToFloatOperation,
|
~distToFloatOperation: Operation.distToFloatOperation,
|
||||||
) => {
|
) => {
|
||||||
let symbolicSolution = switch t {
|
let symbolicSolution = switch (t: t) {
|
||||||
| #Symbolic(r) =>
|
| Symbolic(r) =>
|
||||||
switch SymbolicDist.T.operate(distToFloatOperation, r) {
|
switch SymbolicDist.T.operate(distToFloatOperation, r) {
|
||||||
| Ok(f) => Some(f)
|
| Ok(f) => Some(f)
|
||||||
| _ => None
|
| _ => None
|
||||||
|
@ -53,10 +53,10 @@ let toFloatOperation = (
|
||||||
// This is tricky because the case of discrete distributions.
|
// This is tricky because the case of discrete distributions.
|
||||||
// Also, change the outputXYPoints/pointSetDistLength details
|
// Also, change the outputXYPoints/pointSetDistLength details
|
||||||
let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSetDist, error> => {
|
let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSetDist, error> => {
|
||||||
switch t {
|
switch (t: t) {
|
||||||
| #PointSet(pointSet) => Ok(pointSet)
|
| PointSet(pointSet) => Ok(pointSet)
|
||||||
| #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
|
| Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
|
||||||
| #SampleSet(r) => {
|
| SampleSet(r) => {
|
||||||
let response = SampleSet.toPointSetDist(
|
let response = SampleSet.toPointSetDist(
|
||||||
~samples=r,
|
~samples=r,
|
||||||
~samplingInputs={
|
~samplingInputs={
|
||||||
|
@ -76,11 +76,11 @@ let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSe
|
||||||
}
|
}
|
||||||
|
|
||||||
module Truncate = {
|
module Truncate = {
|
||||||
let trySymbolicSimplification = (leftCutoff, rightCutoff, t): option<t> =>
|
let trySymbolicSimplification = (leftCutoff, rightCutoff, t: t): option<t> =>
|
||||||
switch (leftCutoff, rightCutoff, t) {
|
switch (leftCutoff, rightCutoff, t) {
|
||||||
| (None, None, _) => None
|
| (None, None, _) => None
|
||||||
| (lc, rc, #Symbolic(#Uniform(u))) if lc < rc =>
|
| (lc, rc, Symbolic(#Uniform(u))) if lc < rc =>
|
||||||
Some(#Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
|
Some(Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
|
||||||
| _ => None
|
| _ => None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,9 +98,9 @@ module Truncate = {
|
||||||
switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
|
switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
|
||||||
| Some(r) => Ok(r)
|
| Some(r) => Ok(r)
|
||||||
| None =>
|
| None =>
|
||||||
toPointSetFn(t)->E.R2.fmap(t =>
|
toPointSetFn(t)->E.R2.fmap(t => {
|
||||||
#PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t))
|
GenericDist_Types.PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t))
|
||||||
)
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -122,7 +122,7 @@ module AlgebraicCombination = {
|
||||||
t2: t,
|
t2: t,
|
||||||
): option<result<SymbolicDistTypes.symbolicDist, string>> =>
|
): option<result<SymbolicDistTypes.symbolicDist, string>> =>
|
||||||
switch (arithmeticOperation, t1, t2) {
|
switch (arithmeticOperation, t1, t2) {
|
||||||
| (arithmeticOperation, #Symbolic(d1), #Symbolic(d2)) =>
|
| (arithmeticOperation, Symbolic(d1), Symbolic(d2)) =>
|
||||||
switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) {
|
switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) {
|
||||||
| #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist))
|
| #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist))
|
||||||
| #Error(er) => Some(Error(er))
|
| #Error(er) => Some(Error(er))
|
||||||
|
@ -156,11 +156,11 @@ module AlgebraicCombination = {
|
||||||
//I'm (Ozzie) really just guessing here, very little idea what's best
|
//I'm (Ozzie) really just guessing here, very little idea what's best
|
||||||
let expectedConvolutionCost: t => int = x =>
|
let expectedConvolutionCost: t => int = x =>
|
||||||
switch x {
|
switch x {
|
||||||
| #Symbolic(#Float(_)) => 1
|
| Symbolic(#Float(_)) => 1
|
||||||
| #Symbolic(_) => 1000
|
| Symbolic(_) => 1000
|
||||||
| #PointSet(Discrete(m)) => m.xyShape->XYShape.T.length
|
| PointSet(Discrete(m)) => m.xyShape->XYShape.T.length
|
||||||
| #PointSet(Mixed(_)) => 1000
|
| PointSet(Mixed(_)) => 1000
|
||||||
| #PointSet(Continuous(_)) => 1000
|
| PointSet(Continuous(_)) => 1000
|
||||||
| _ => 1000
|
| _ => 1000
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,14 +177,24 @@ module AlgebraicCombination = {
|
||||||
~t2: t,
|
~t2: t,
|
||||||
): result<t, error> => {
|
): result<t, error> => {
|
||||||
switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
|
switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
|
||||||
| Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist))
|
| Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist))
|
||||||
| Some(Error(e)) => Error(Other(e))
|
| Some(Error(e)) => Error(Other(e))
|
||||||
| None =>
|
| None =>
|
||||||
switch chooseConvolutionOrMonteCarlo(t1, t2) {
|
switch chooseConvolutionOrMonteCarlo(t1, t2) {
|
||||||
| #CalculateWithMonteCarlo =>
|
| #CalculateWithMonteCarlo =>
|
||||||
runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(r => #SampleSet(r))
|
runMonteCarlo(
|
||||||
|
toSampleSetFn,
|
||||||
|
arithmeticOperation,
|
||||||
|
t1,
|
||||||
|
t2,
|
||||||
|
)->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
|
||||||
| #CalculateWithConvolution =>
|
| #CalculateWithConvolution =>
|
||||||
runConvolution(toPointSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(r => #PointSet(r))
|
runConvolution(
|
||||||
|
toPointSetFn,
|
||||||
|
arithmeticOperation,
|
||||||
|
t1,
|
||||||
|
t2,
|
||||||
|
)->E.R2.fmap(r => GenericDist_Types.PointSet(r))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -207,7 +217,7 @@ let pointwiseCombination = (
|
||||||
t2,
|
t2,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
->E.R2.fmap(r => #PointSet(r))
|
->E.R2.fmap(r => GenericDist_Types.PointSet(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
let pointwiseCombinationFloat = (
|
let pointwiseCombinationFloat = (
|
||||||
|
@ -232,7 +242,7 @@ let pointwiseCombinationFloat = (
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
m->E.R2.fmap(r => #PointSet(r))
|
m->E.R2.fmap(r => GenericDist_Types.PointSet(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
//Note: The result should always cumulatively sum to 1. This would be good to test.
|
//Note: The result should always cumulatively sum to 1. This would be good to test.
|
||||||
|
|
|
@ -71,14 +71,14 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
|
|
||||||
let toPointSetFn = r => {
|
let toPointSetFn = r => {
|
||||||
switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) {
|
switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) {
|
||||||
| Dist(#PointSet(p)) => Ok(p)
|
| Dist(PointSet(p)) => Ok(p)
|
||||||
| e => Error(OutputLocal.toErrorOrUnreachable(e))
|
| e => Error(OutputLocal.toErrorOrUnreachable(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let toSampleSetFn = r => {
|
let toSampleSetFn = r => {
|
||||||
switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
|
switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
|
||||||
| Dist(#SampleSet(p)) => Ok(p)
|
| Dist(SampleSet(p)) => Ok(p)
|
||||||
| e => Error(OutputLocal.toErrorOrUnreachable(e))
|
| e => Error(OutputLocal.toErrorOrUnreachable(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -114,10 +114,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
| #toDist(#toPointSet) =>
|
| #toDist(#toPointSet) =>
|
||||||
dist
|
dist
|
||||||
->GenericDist.toPointSet(~xyPointLength, ~sampleCount)
|
->GenericDist.toPointSet(~xyPointLength, ~sampleCount)
|
||||||
->E.R2.fmap(r => Dist(#PointSet(r)))
|
->E.R2.fmap(r => Dist(PointSet(r)))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| #toDist(#toSampleSet(n)) =>
|
| #toDist(#toSampleSet(n)) =>
|
||||||
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(#SampleSet(r)))->OutputLocal.fromResult
|
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
|
||||||
| #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
|
| #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
|
||||||
| #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) =>
|
| #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) =>
|
||||||
dist
|
dist
|
||||||
|
|
|
@ -15,7 +15,11 @@ let runFromDist: (
|
||||||
~functionCallInfo: GenericDist_Types.Operation.fromDist,
|
~functionCallInfo: GenericDist_Types.Operation.fromDist,
|
||||||
GenericDist_Types.genericDist,
|
GenericDist_Types.genericDist,
|
||||||
) => outputType
|
) => outputType
|
||||||
let runFromFloat: (~env: env, ~functionCallInfo: GenericDist_Types.Operation.fromDist, float) => outputType
|
let runFromFloat: (
|
||||||
|
~env: env,
|
||||||
|
~functionCallInfo: GenericDist_Types.Operation.fromDist,
|
||||||
|
float,
|
||||||
|
) => outputType
|
||||||
|
|
||||||
module Output: {
|
module Output: {
|
||||||
type t = outputType
|
type t = outputType
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
type genericDist = [
|
type genericDist =
|
||||||
| #PointSet(PointSetTypes.pointSetDist)
|
| PointSet(PointSetTypes.pointSetDist)
|
||||||
| #SampleSet(array<float>)
|
| SampleSet(array<float>)
|
||||||
| #Symbolic(SymbolicDistTypes.symbolicDist)
|
| Symbolic(SymbolicDistTypes.symbolicDist)
|
||||||
]
|
|
||||||
|
|
||||||
type error =
|
type error =
|
||||||
| NotYetImplemented
|
| NotYetImplemented
|
||||||
|
|
Loading…
Reference in New Issue
Block a user