Converted most of Operation to not be polymorphic
This commit is contained in:
parent
680726e8b0
commit
15534b10ce
|
@ -22,14 +22,14 @@ let toExt: option<'a> => 'a = E.O.toExt(
|
||||||
|
|
||||||
describe("normalize", () => {
|
describe("normalize", () => {
|
||||||
test("has no impact on normal dist", () => {
|
test("has no impact on normal dist", () => {
|
||||||
let result = run(#fromDist(#toDist(#normalize), normalDist))
|
let result = run(FromDist(ToDist(Normalize), normalDist))
|
||||||
expect(result)->toEqual(Dist(normalDist))
|
expect(result)->toEqual(Dist(normalDist))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe("mean", () => {
|
describe("mean", () => {
|
||||||
test("for a normal distribution", () => {
|
test("for a normal distribution", () => {
|
||||||
let result = GenericDist_GenericOperation.run(~env, #fromDist(#toFloat(#Mean), normalDist))
|
let result = GenericDist_GenericOperation.run(~env, FromDist(ToFloat(#Mean), normalDist))
|
||||||
expect(result)->toEqual(Float(5.0))
|
expect(result)->toEqual(Float(5.0))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -37,8 +37,8 @@ describe("mean", () => {
|
||||||
describe("mixture", () => {
|
describe("mixture", () => {
|
||||||
test("on two normal distributions", () => {
|
test("on two normal distributions", () => {
|
||||||
let result =
|
let result =
|
||||||
run(#mixture([(normalDist10, 0.5), (normalDist20, 0.5)]))
|
run(Mixture([(normalDist10, 0.5), (normalDist20, 0.5)]))
|
||||||
->outputMap(#fromDist(#toFloat(#Mean)))
|
->outputMap(FromDist(ToFloat(#Mean)))
|
||||||
->toFloat
|
->toFloat
|
||||||
->toExt
|
->toExt
|
||||||
expect(result)->toBeCloseTo(15.28)
|
expect(result)->toBeCloseTo(15.28)
|
||||||
|
@ -48,8 +48,8 @@ describe("mixture", () => {
|
||||||
describe("toPointSet", () => {
|
describe("toPointSet", () => {
|
||||||
test("on symbolic normal distribution", () => {
|
test("on symbolic normal distribution", () => {
|
||||||
let result =
|
let result =
|
||||||
run(#fromDist(#toDist(#toPointSet), normalDist))
|
run(FromDist(ToDist(ToPointSet), normalDist))
|
||||||
->outputMap(#fromDist(#toFloat(#Mean)))
|
->outputMap(FromDist(ToFloat(#Mean)))
|
||||||
->toFloat
|
->toFloat
|
||||||
->toExt
|
->toExt
|
||||||
expect(result)->toBeCloseTo(5.09)
|
expect(result)->toBeCloseTo(5.09)
|
||||||
|
@ -57,18 +57,18 @@ describe("toPointSet", () => {
|
||||||
|
|
||||||
test("on sample set distribution with under 4 points", () => {
|
test("on sample set distribution with under 4 points", () => {
|
||||||
let result =
|
let result =
|
||||||
run(#fromDist(#toDist(#toPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
|
run(FromDist(ToDist(ToPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
|
||||||
#fromDist(#toFloat(#Mean)),
|
FromDist(ToFloat(#Mean)),
|
||||||
)
|
)
|
||||||
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))
|
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))
|
||||||
})
|
})
|
||||||
|
|
||||||
Skip.test("on sample set", () => {
|
Skip.test("on sample set", () => {
|
||||||
let result =
|
let result =
|
||||||
run(#fromDist(#toDist(#toPointSet), normalDist))
|
run(FromDist(ToDist(ToPointSet), normalDist))
|
||||||
->outputMap(#fromDist(#toDist(#toSampleSet(1000))))
|
->outputMap(FromDist(ToDist(ToSampleSet(1000))))
|
||||||
->outputMap(#fromDist(#toDist(#toPointSet)))
|
->outputMap(FromDist(ToDist(ToPointSet)))
|
||||||
->outputMap(#fromDist(#toFloat(#Mean)))
|
->outputMap(FromDist(ToFloat(#Mean)))
|
||||||
->toFloat
|
->toFloat
|
||||||
->toExt
|
->toExt
|
||||||
expect(result)->toBeCloseTo(5.09)
|
expect(result)->toBeCloseTo(5.09)
|
||||||
|
|
|
@ -59,4 +59,4 @@ let mixture: (
|
||||||
array<(t, float)>,
|
array<(t, float)>,
|
||||||
~scaleMultiplyFn: scaleMultiplyFn,
|
~scaleMultiplyFn: scaleMultiplyFn,
|
||||||
~pointwiseAddFn: pointwiseAddFn,
|
~pointwiseAddFn: pointwiseAddFn,
|
||||||
) => result<t, error>
|
) => result<t, error>
|
|
@ -70,14 +70,14 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
}
|
}
|
||||||
|
|
||||||
let toPointSetFn = r => {
|
let toPointSetFn = r => {
|
||||||
switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) {
|
switch reCall(~functionCallInfo=FromDist(ToDist(ToPointSet), r), ()) {
|
||||||
| Dist(PointSet(p)) => Ok(p)
|
| Dist(PointSet(p)) => Ok(p)
|
||||||
| e => Error(OutputLocal.toErrorOrUnreachable(e))
|
| e => Error(OutputLocal.toErrorOrUnreachable(e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let toSampleSetFn = r => {
|
let toSampleSetFn = r => {
|
||||||
switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
|
switch reCall(~functionCallInfo=FromDist(ToDist(ToSampleSet(sampleCount)), r), ()) {
|
||||||
| Dist(SampleSet(p)) => Ok(p)
|
| Dist(SampleSet(p)) => Ok(p)
|
||||||
| e => Error(OutputLocal.toErrorOrUnreachable(e))
|
| e => Error(OutputLocal.toErrorOrUnreachable(e))
|
||||||
}
|
}
|
||||||
|
@ -85,51 +85,51 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
|
|
||||||
let scaleMultiply = (r, weight) =>
|
let scaleMultiply = (r, weight) =>
|
||||||
reCall(
|
reCall(
|
||||||
~functionCallInfo=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r),
|
~functionCallInfo=FromDist(ToDistCombination(Pointwise, #Multiply, #Float(weight)), r),
|
||||||
(),
|
(),
|
||||||
)->OutputLocal.toDistR
|
)->OutputLocal.toDistR
|
||||||
|
|
||||||
let pointwiseAdd = (r1, r2) =>
|
let pointwiseAdd = (r1, r2) =>
|
||||||
reCall(
|
reCall(
|
||||||
~functionCallInfo=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1),
|
~functionCallInfo=FromDist(ToDistCombination(Pointwise, #Add, #Dist(r2)), r1),
|
||||||
(),
|
(),
|
||||||
)->OutputLocal.toDistR
|
)->OutputLocal.toDistR
|
||||||
|
|
||||||
let fromDistFn = (subFnName: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
|
let fromDistFn = (subFnName: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
|
||||||
switch subFnName {
|
switch subFnName {
|
||||||
| #toFloat(distToFloatOperation) =>
|
| ToFloat(distToFloatOperation) =>
|
||||||
GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation)
|
GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation)
|
||||||
->E.R2.fmap(r => Float(r))
|
->E.R2.fmap(r => Float(r))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| #toString => dist->GenericDist.toString->String
|
| ToString => dist->GenericDist.toString->String
|
||||||
| #toDist(#inspect) => {
|
| ToDist(Inspect) => {
|
||||||
Js.log2("Console log requested: ", dist)
|
Js.log2("Console log requested: ", dist)
|
||||||
Dist(dist)
|
Dist(dist)
|
||||||
}
|
}
|
||||||
| #toDist(#normalize) => dist->GenericDist.normalize->Dist
|
| ToDist(Normalize) => dist->GenericDist.normalize->Dist
|
||||||
| #toDist(#truncate(leftCutoff, rightCutoff)) =>
|
| ToDist(Truncate(leftCutoff, rightCutoff)) =>
|
||||||
GenericDist.truncate(~toPointSetFn, ~leftCutoff, ~rightCutoff, dist, ())
|
GenericDist.truncate(~toPointSetFn, ~leftCutoff, ~rightCutoff, dist, ())
|
||||||
->E.R2.fmap(r => Dist(r))
|
->E.R2.fmap(r => Dist(r))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| #toDist(#toPointSet) =>
|
| ToDist(ToSampleSet(n)) =>
|
||||||
|
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
|
||||||
|
| ToDist(ToPointSet) =>
|
||||||
dist
|
dist
|
||||||
->GenericDist.toPointSet(~xyPointLength, ~sampleCount)
|
->GenericDist.toPointSet(~xyPointLength, ~sampleCount)
|
||||||
->E.R2.fmap(r => Dist(PointSet(r)))
|
->E.R2.fmap(r => Dist(PointSet(r)))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| #toDist(#toSampleSet(n)) =>
|
| ToDistCombination(Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
|
||||||
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
|
| ToDistCombination(Algebraic, arithmeticOperation, #Dist(t2)) =>
|
||||||
| #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
|
|
||||||
| #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) =>
|
|
||||||
dist
|
dist
|
||||||
->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2)
|
->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2)
|
||||||
->E.R2.fmap(r => Dist(r))
|
->E.R2.fmap(r => Dist(r))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| #toDistCombination(#Pointwise, arithmeticOperation, #Dist(t2)) =>
|
| ToDistCombination(Pointwise, arithmeticOperation, #Dist(t2)) =>
|
||||||
dist
|
dist
|
||||||
->GenericDist.pointwiseCombination(~toPointSetFn, ~arithmeticOperation, ~t2)
|
->GenericDist.pointwiseCombination(~toPointSetFn, ~arithmeticOperation, ~t2)
|
||||||
->E.R2.fmap(r => Dist(r))
|
->E.R2.fmap(r => Dist(r))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| #toDistCombination(#Pointwise, arithmeticOperation, #Float(float)) =>
|
| ToDistCombination(Pointwise, arithmeticOperation, #Float(float)) =>
|
||||||
dist
|
dist
|
||||||
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~arithmeticOperation, ~float)
|
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~arithmeticOperation, ~float)
|
||||||
->E.R2.fmap(r => Dist(r))
|
->E.R2.fmap(r => Dist(r))
|
||||||
|
@ -137,10 +137,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
}
|
}
|
||||||
|
|
||||||
switch functionCallInfo {
|
switch functionCallInfo {
|
||||||
| #fromDist(subFnName, dist) => fromDistFn(subFnName, dist)
|
| FromDist(subFnName, dist) => fromDistFn(subFnName, dist)
|
||||||
| #fromFloat(subFnName, float) =>
|
| FromFloat(subFnName, float) =>
|
||||||
reCall(~functionCallInfo=#fromDist(subFnName, GenericDist.fromFloat(float)), ())
|
reCall(~functionCallInfo=FromDist(subFnName, GenericDist.fromFloat(float)), ())
|
||||||
| #mixture(dists) =>
|
| Mixture(dists) =>
|
||||||
dists
|
dists
|
||||||
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd)
|
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd)
|
||||||
->E.R2.fmap(r => Dist(r))
|
->E.R2.fmap(r => Dist(r))
|
||||||
|
@ -148,9 +148,8 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let runFromDist = (~env, ~functionCallInfo, dist) => run(~env, #fromDist(functionCallInfo, dist))
|
let runFromDist = (~env, ~functionCallInfo, dist) => run(~env, FromDist(functionCallInfo, dist))
|
||||||
let runFromFloat = (~env, ~functionCallInfo, float) =>
|
let runFromFloat = (~env, ~functionCallInfo, float) => run(~env, FromFloat(functionCallInfo, float))
|
||||||
run(~env, #fromFloat(functionCallInfo, float))
|
|
||||||
|
|
||||||
module Output = {
|
module Output = {
|
||||||
include OutputLocal
|
include OutputLocal
|
||||||
|
@ -161,11 +160,11 @@ module Output = {
|
||||||
functionCallInfo: GenericDist_Types.Operation.singleParamaterFunction,
|
functionCallInfo: GenericDist_Types.Operation.singleParamaterFunction,
|
||||||
): outputType => {
|
): outputType => {
|
||||||
let newFnCall: result<functionCallInfo, error> = switch (functionCallInfo, input) {
|
let newFnCall: result<functionCallInfo, error> = switch (functionCallInfo, input) {
|
||||||
| (#fromDist(fromDist), Dist(o)) => Ok(#fromDist(fromDist, o))
|
| (FromDist(fromDist), Dist(o)) => Ok(FromDist(fromDist, o))
|
||||||
| (#fromFloat(fromDist), Float(o)) => Ok(#fromFloat(fromDist, o))
|
| (FromFloat(fromDist), Float(o)) => Ok(FromFloat(fromDist, o))
|
||||||
| (_, GenDistError(r)) => Error(r)
|
| (_, GenDistError(r)) => Error(r)
|
||||||
| (#fromDist(_), _) => Error(Other("Expected dist, got something else"))
|
| (FromDist(_), _) => Error(Other("Expected dist, got something else"))
|
||||||
| (#fromFloat(_), _) => Error(Other("Expected float, got something else"))
|
| (FromFloat(_), _) => Error(Other("Expected float, got something else"))
|
||||||
}
|
}
|
||||||
newFnCall->E.R2.fmap(run(~env))->OutputLocal.fromResult
|
newFnCall->E.R2.fmap(run(~env))->OutputLocal.fromResult
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,10 +10,9 @@ type error =
|
||||||
| Other(string)
|
| Other(string)
|
||||||
|
|
||||||
module Operation = {
|
module Operation = {
|
||||||
type direction = [
|
type direction =
|
||||||
| #Algebraic
|
| Algebraic
|
||||||
| #Pointwise
|
| Pointwise
|
||||||
]
|
|
||||||
|
|
||||||
type arithmeticOperation = [
|
type arithmeticOperation = [
|
||||||
| #Add
|
| #Add
|
||||||
|
@ -42,51 +41,50 @@ module Operation = {
|
||||||
| #Sample
|
| #Sample
|
||||||
]
|
]
|
||||||
|
|
||||||
type toDist = [
|
type toDist =
|
||||||
| #normalize
|
| Normalize
|
||||||
| #toPointSet
|
| ToPointSet
|
||||||
| #toSampleSet(int)
|
| ToSampleSet(int)
|
||||||
| #truncate(option<float>, option<float>)
|
| Truncate(option<float>, option<float>)
|
||||||
| #inspect
|
| Inspect
|
||||||
]
|
|
||||||
|
|
||||||
type toFloatArray = [
|
type toFloatArray = Sample(int)
|
||||||
| #Sample(int)
|
|
||||||
]
|
|
||||||
|
|
||||||
type fromDist = [
|
type fromDist =
|
||||||
| #toFloat(toFloat)
|
| ToFloat(toFloat)
|
||||||
| #toDist(toDist)
|
| ToDist(toDist)
|
||||||
| #toDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)])
|
| ToDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)])
|
||||||
| #toString
|
| ToString
|
||||||
]
|
|
||||||
|
|
||||||
type singleParamaterFunction = [
|
type singleParamaterFunction =
|
||||||
| #fromDist(fromDist)
|
| FromDist(fromDist)
|
||||||
| #fromFloat(fromDist)
|
| FromFloat(fromDist)
|
||||||
]
|
|
||||||
|
|
||||||
type genericFunctionCallInfo = [
|
type genericFunctionCallInfo =
|
||||||
| #fromDist(fromDist, genericDist)
|
| FromDist(fromDist, genericDist)
|
||||||
| #fromFloat(fromDist, float)
|
| FromFloat(fromDist, float)
|
||||||
| #mixture(array<(genericDist, float)>)
|
| Mixture(array<(genericDist, float)>)
|
||||||
]
|
|
||||||
|
|
||||||
//TODO: Should support all genericFunctionCallInfo types
|
let distCallToString = (distFunction: fromDist): string =>
|
||||||
let toString = (distFunction: fromDist): string =>
|
|
||||||
switch distFunction {
|
switch distFunction {
|
||||||
| #toFloat(#Cdf(r)) => `cdf(${E.Float.toFixed(r)})`
|
| ToFloat(#Cdf(r)) => `cdf(${E.Float.toFixed(r)})`
|
||||||
| #toFloat(#Inv(r)) => `inv(${E.Float.toFixed(r)})`
|
| ToFloat(#Inv(r)) => `inv(${E.Float.toFixed(r)})`
|
||||||
| #toFloat(#Mean) => `mean`
|
| ToFloat(#Mean) => `mean`
|
||||||
| #toFloat(#Pdf(r)) => `pdf(${E.Float.toFixed(r)})`
|
| ToFloat(#Pdf(r)) => `pdf(${E.Float.toFixed(r)})`
|
||||||
| #toFloat(#Sample) => `sample`
|
| ToFloat(#Sample) => `sample`
|
||||||
| #toDist(#normalize) => `normalize`
|
| ToDist(Normalize) => `normalize`
|
||||||
| #toDist(#toPointSet) => `toPointSet`
|
| ToDist(ToPointSet) => `toPointSet`
|
||||||
| #toDist(#toSampleSet(r)) => `toSampleSet(${E.I.toString(r)})`
|
| ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})`
|
||||||
| #toDist(#truncate(_, _)) => `truncate`
|
| ToDist(Truncate(_, _)) => `truncate`
|
||||||
| #toDist(#inspect) => `inspect`
|
| ToDist(Inspect) => `inspect`
|
||||||
| #toString => `toString`
|
| ToString => `toString`
|
||||||
| #toDistCombination(#Algebraic, _, _) => `algebraic`
|
| ToDistCombination(Algebraic, _, _) => `algebraic`
|
||||||
| #toDistCombination(#Pointwise, _, _) => `pointwise`
|
| ToDistCombination(Pointwise, _, _) => `pointwise`
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
let toString = (d: genericFunctionCallInfo): string =>
|
||||||
|
switch d {
|
||||||
|
| FromDist(f, _) | FromFloat(f, _) => distCallToString(f)
|
||||||
|
| Mixture(_) => `mixture`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user