Converted most of Operation to not be polymorphic

This commit is contained in:
Ozzie Gooen 2022-03-31 14:51:42 -04:00
parent 680726e8b0
commit 15534b10ce
4 changed files with 82 additions and 85 deletions

View File

@ -22,14 +22,14 @@ let toExt: option<'a> => 'a = E.O.toExt(
describe("normalize", () => { describe("normalize", () => {
test("has no impact on normal dist", () => { test("has no impact on normal dist", () => {
let result = run(#fromDist(#toDist(#normalize), normalDist)) let result = run(FromDist(ToDist(Normalize), normalDist))
expect(result)->toEqual(Dist(normalDist)) expect(result)->toEqual(Dist(normalDist))
}) })
}) })
describe("mean", () => { describe("mean", () => {
test("for a normal distribution", () => { test("for a normal distribution", () => {
let result = GenericDist_GenericOperation.run(~env, #fromDist(#toFloat(#Mean), normalDist)) let result = GenericDist_GenericOperation.run(~env, FromDist(ToFloat(#Mean), normalDist))
expect(result)->toEqual(Float(5.0)) expect(result)->toEqual(Float(5.0))
}) })
}) })
@ -37,8 +37,8 @@ describe("mean", () => {
describe("mixture", () => { describe("mixture", () => {
test("on two normal distributions", () => { test("on two normal distributions", () => {
let result = let result =
run(#mixture([(normalDist10, 0.5), (normalDist20, 0.5)])) run(Mixture([(normalDist10, 0.5), (normalDist20, 0.5)]))
->outputMap(#fromDist(#toFloat(#Mean))) ->outputMap(FromDist(ToFloat(#Mean)))
->toFloat ->toFloat
->toExt ->toExt
expect(result)->toBeCloseTo(15.28) expect(result)->toBeCloseTo(15.28)
@ -48,8 +48,8 @@ describe("mixture", () => {
describe("toPointSet", () => { describe("toPointSet", () => {
test("on symbolic normal distribution", () => { test("on symbolic normal distribution", () => {
let result = let result =
run(#fromDist(#toDist(#toPointSet), normalDist)) run(FromDist(ToDist(ToPointSet), normalDist))
->outputMap(#fromDist(#toFloat(#Mean))) ->outputMap(FromDist(ToFloat(#Mean)))
->toFloat ->toFloat
->toExt ->toExt
expect(result)->toBeCloseTo(5.09) expect(result)->toBeCloseTo(5.09)
@ -57,18 +57,18 @@ describe("toPointSet", () => {
test("on sample set distribution with under 4 points", () => { test("on sample set distribution with under 4 points", () => {
let result = let result =
run(#fromDist(#toDist(#toPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap( run(FromDist(ToDist(ToPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
#fromDist(#toFloat(#Mean)), FromDist(ToFloat(#Mean)),
) )
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed"))) expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))
}) })
Skip.test("on sample set", () => { Skip.test("on sample set", () => {
let result = let result =
run(#fromDist(#toDist(#toPointSet), normalDist)) run(FromDist(ToDist(ToPointSet), normalDist))
->outputMap(#fromDist(#toDist(#toSampleSet(1000)))) ->outputMap(FromDist(ToDist(ToSampleSet(1000))))
->outputMap(#fromDist(#toDist(#toPointSet))) ->outputMap(FromDist(ToDist(ToPointSet)))
->outputMap(#fromDist(#toFloat(#Mean))) ->outputMap(FromDist(ToFloat(#Mean)))
->toFloat ->toFloat
->toExt ->toExt
expect(result)->toBeCloseTo(5.09) expect(result)->toBeCloseTo(5.09)

View File

@ -59,4 +59,4 @@ let mixture: (
array<(t, float)>, array<(t, float)>,
~scaleMultiplyFn: scaleMultiplyFn, ~scaleMultiplyFn: scaleMultiplyFn,
~pointwiseAddFn: pointwiseAddFn, ~pointwiseAddFn: pointwiseAddFn,
) => result<t, error> ) => result<t, error>

View File

@ -70,14 +70,14 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
} }
let toPointSetFn = r => { let toPointSetFn = r => {
switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) { switch reCall(~functionCallInfo=FromDist(ToDist(ToPointSet), r), ()) {
| Dist(PointSet(p)) => Ok(p) | Dist(PointSet(p)) => Ok(p)
| e => Error(OutputLocal.toErrorOrUnreachable(e)) | e => Error(OutputLocal.toErrorOrUnreachable(e))
} }
} }
let toSampleSetFn = r => { let toSampleSetFn = r => {
switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) { switch reCall(~functionCallInfo=FromDist(ToDist(ToSampleSet(sampleCount)), r), ()) {
| Dist(SampleSet(p)) => Ok(p) | Dist(SampleSet(p)) => Ok(p)
| e => Error(OutputLocal.toErrorOrUnreachable(e)) | e => Error(OutputLocal.toErrorOrUnreachable(e))
} }
@ -85,51 +85,51 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
let scaleMultiply = (r, weight) => let scaleMultiply = (r, weight) =>
reCall( reCall(
~functionCallInfo=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r), ~functionCallInfo=FromDist(ToDistCombination(Pointwise, #Multiply, #Float(weight)), r),
(), (),
)->OutputLocal.toDistR )->OutputLocal.toDistR
let pointwiseAdd = (r1, r2) => let pointwiseAdd = (r1, r2) =>
reCall( reCall(
~functionCallInfo=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1), ~functionCallInfo=FromDist(ToDistCombination(Pointwise, #Add, #Dist(r2)), r1),
(), (),
)->OutputLocal.toDistR )->OutputLocal.toDistR
let fromDistFn = (subFnName: GenericDist_Types.Operation.fromDist, dist: genericDist) => let fromDistFn = (subFnName: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
switch subFnName { switch subFnName {
| #toFloat(distToFloatOperation) => | ToFloat(distToFloatOperation) =>
GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation) GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation)
->E.R2.fmap(r => Float(r)) ->E.R2.fmap(r => Float(r))
->OutputLocal.fromResult ->OutputLocal.fromResult
| #toString => dist->GenericDist.toString->String | ToString => dist->GenericDist.toString->String
| #toDist(#inspect) => { | ToDist(Inspect) => {
Js.log2("Console log requested: ", dist) Js.log2("Console log requested: ", dist)
Dist(dist) Dist(dist)
} }
| #toDist(#normalize) => dist->GenericDist.normalize->Dist | ToDist(Normalize) => dist->GenericDist.normalize->Dist
| #toDist(#truncate(leftCutoff, rightCutoff)) => | ToDist(Truncate(leftCutoff, rightCutoff)) =>
GenericDist.truncate(~toPointSetFn, ~leftCutoff, ~rightCutoff, dist, ()) GenericDist.truncate(~toPointSetFn, ~leftCutoff, ~rightCutoff, dist, ())
->E.R2.fmap(r => Dist(r)) ->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult ->OutputLocal.fromResult
| #toDist(#toPointSet) => | ToDist(ToSampleSet(n)) =>
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
| ToDist(ToPointSet) =>
dist dist
->GenericDist.toPointSet(~xyPointLength, ~sampleCount) ->GenericDist.toPointSet(~xyPointLength, ~sampleCount)
->E.R2.fmap(r => Dist(PointSet(r))) ->E.R2.fmap(r => Dist(PointSet(r)))
->OutputLocal.fromResult ->OutputLocal.fromResult
| #toDist(#toSampleSet(n)) => | ToDistCombination(Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult | ToDistCombination(Algebraic, arithmeticOperation, #Dist(t2)) =>
| #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
| #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) =>
dist dist
->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2) ->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2)
->E.R2.fmap(r => Dist(r)) ->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult ->OutputLocal.fromResult
| #toDistCombination(#Pointwise, arithmeticOperation, #Dist(t2)) => | ToDistCombination(Pointwise, arithmeticOperation, #Dist(t2)) =>
dist dist
->GenericDist.pointwiseCombination(~toPointSetFn, ~arithmeticOperation, ~t2) ->GenericDist.pointwiseCombination(~toPointSetFn, ~arithmeticOperation, ~t2)
->E.R2.fmap(r => Dist(r)) ->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult ->OutputLocal.fromResult
| #toDistCombination(#Pointwise, arithmeticOperation, #Float(float)) => | ToDistCombination(Pointwise, arithmeticOperation, #Float(float)) =>
dist dist
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~arithmeticOperation, ~float) ->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~arithmeticOperation, ~float)
->E.R2.fmap(r => Dist(r)) ->E.R2.fmap(r => Dist(r))
@ -137,10 +137,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
} }
switch functionCallInfo { switch functionCallInfo {
| #fromDist(subFnName, dist) => fromDistFn(subFnName, dist) | FromDist(subFnName, dist) => fromDistFn(subFnName, dist)
| #fromFloat(subFnName, float) => | FromFloat(subFnName, float) =>
reCall(~functionCallInfo=#fromDist(subFnName, GenericDist.fromFloat(float)), ()) reCall(~functionCallInfo=FromDist(subFnName, GenericDist.fromFloat(float)), ())
| #mixture(dists) => | Mixture(dists) =>
dists dists
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd) ->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd)
->E.R2.fmap(r => Dist(r)) ->E.R2.fmap(r => Dist(r))
@ -148,9 +148,8 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
} }
} }
let runFromDist = (~env, ~functionCallInfo, dist) => run(~env, #fromDist(functionCallInfo, dist)) let runFromDist = (~env, ~functionCallInfo, dist) => run(~env, FromDist(functionCallInfo, dist))
let runFromFloat = (~env, ~functionCallInfo, float) => let runFromFloat = (~env, ~functionCallInfo, float) => run(~env, FromFloat(functionCallInfo, float))
run(~env, #fromFloat(functionCallInfo, float))
module Output = { module Output = {
include OutputLocal include OutputLocal
@ -161,11 +160,11 @@ module Output = {
functionCallInfo: GenericDist_Types.Operation.singleParamaterFunction, functionCallInfo: GenericDist_Types.Operation.singleParamaterFunction,
): outputType => { ): outputType => {
let newFnCall: result<functionCallInfo, error> = switch (functionCallInfo, input) { let newFnCall: result<functionCallInfo, error> = switch (functionCallInfo, input) {
| (#fromDist(fromDist), Dist(o)) => Ok(#fromDist(fromDist, o)) | (FromDist(fromDist), Dist(o)) => Ok(FromDist(fromDist, o))
| (#fromFloat(fromDist), Float(o)) => Ok(#fromFloat(fromDist, o)) | (FromFloat(fromDist), Float(o)) => Ok(FromFloat(fromDist, o))
| (_, GenDistError(r)) => Error(r) | (_, GenDistError(r)) => Error(r)
| (#fromDist(_), _) => Error(Other("Expected dist, got something else")) | (FromDist(_), _) => Error(Other("Expected dist, got something else"))
| (#fromFloat(_), _) => Error(Other("Expected float, got something else")) | (FromFloat(_), _) => Error(Other("Expected float, got something else"))
} }
newFnCall->E.R2.fmap(run(~env))->OutputLocal.fromResult newFnCall->E.R2.fmap(run(~env))->OutputLocal.fromResult
} }

View File

@ -10,10 +10,9 @@ type error =
| Other(string) | Other(string)
module Operation = { module Operation = {
type direction = [ type direction =
| #Algebraic | Algebraic
| #Pointwise | Pointwise
]
type arithmeticOperation = [ type arithmeticOperation = [
| #Add | #Add
@ -42,51 +41,50 @@ module Operation = {
| #Sample | #Sample
] ]
type toDist = [ type toDist =
| #normalize | Normalize
| #toPointSet | ToPointSet
| #toSampleSet(int) | ToSampleSet(int)
| #truncate(option<float>, option<float>) | Truncate(option<float>, option<float>)
| #inspect | Inspect
]
type toFloatArray = [ type toFloatArray = Sample(int)
| #Sample(int)
]
type fromDist = [ type fromDist =
| #toFloat(toFloat) | ToFloat(toFloat)
| #toDist(toDist) | ToDist(toDist)
| #toDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)]) | ToDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)])
| #toString | ToString
]
type singleParamaterFunction = [ type singleParamaterFunction =
| #fromDist(fromDist) | FromDist(fromDist)
| #fromFloat(fromDist) | FromFloat(fromDist)
]
type genericFunctionCallInfo = [ type genericFunctionCallInfo =
| #fromDist(fromDist, genericDist) | FromDist(fromDist, genericDist)
| #fromFloat(fromDist, float) | FromFloat(fromDist, float)
| #mixture(array<(genericDist, float)>) | Mixture(array<(genericDist, float)>)
]
//TODO: Should support all genericFunctionCallInfo types let distCallToString = (distFunction: fromDist): string =>
let toString = (distFunction: fromDist): string =>
switch distFunction { switch distFunction {
| #toFloat(#Cdf(r)) => `cdf(${E.Float.toFixed(r)})` | ToFloat(#Cdf(r)) => `cdf(${E.Float.toFixed(r)})`
| #toFloat(#Inv(r)) => `inv(${E.Float.toFixed(r)})` | ToFloat(#Inv(r)) => `inv(${E.Float.toFixed(r)})`
| #toFloat(#Mean) => `mean` | ToFloat(#Mean) => `mean`
| #toFloat(#Pdf(r)) => `pdf(${E.Float.toFixed(r)})` | ToFloat(#Pdf(r)) => `pdf(${E.Float.toFixed(r)})`
| #toFloat(#Sample) => `sample` | ToFloat(#Sample) => `sample`
| #toDist(#normalize) => `normalize` | ToDist(Normalize) => `normalize`
| #toDist(#toPointSet) => `toPointSet` | ToDist(ToPointSet) => `toPointSet`
| #toDist(#toSampleSet(r)) => `toSampleSet(${E.I.toString(r)})` | ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})`
| #toDist(#truncate(_, _)) => `truncate` | ToDist(Truncate(_, _)) => `truncate`
| #toDist(#inspect) => `inspect` | ToDist(Inspect) => `inspect`
| #toString => `toString` | ToString => `toString`
| #toDistCombination(#Algebraic, _, _) => `algebraic` | ToDistCombination(Algebraic, _, _) => `algebraic`
| #toDistCombination(#Pointwise, _, _) => `pointwise` | ToDistCombination(Pointwise, _, _) => `pointwise`
} }
}
let toString = (d: genericFunctionCallInfo): string =>
switch d {
| FromDist(f, _) | FromFloat(f, _) => distCallToString(f)
| Mixture(_) => `mixture`
}
}