Added mixture fn for generic distributions
This commit is contained in:
parent
b70e8e02e1
commit
3f678e24a1
|
@ -12,6 +12,8 @@ let sampleN = (n, t: t) =>
|
||||||
| #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
|
| #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let fromFloat = (f: float) => #Symbolic(SymbolicDist.Float.make(f))
|
||||||
|
|
||||||
let toString = (t: t) =>
|
let toString = (t: t) =>
|
||||||
switch t {
|
switch t {
|
||||||
| #PointSet(_) => "Point Set Distribution"
|
| #PointSet(_) => "Point Set Distribution"
|
||||||
|
@ -28,7 +30,6 @@ let normalize = (t: t) =>
|
||||||
|
|
||||||
// let isNormalized = (t:t) =>
|
// let isNormalized = (t:t) =>
|
||||||
|
|
||||||
|
|
||||||
let operationToFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => {
|
let operationToFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => {
|
||||||
let symbolicSolution = switch t {
|
let symbolicSolution = switch t {
|
||||||
| #Symbolic(r) =>
|
| #Symbolic(r) =>
|
||||||
|
@ -214,3 +215,20 @@ let pointwiseCombinationFloat = (
|
||||||
})
|
})
|
||||||
} |> E.R.fmap(r => #PointSet(r))
|
} |> E.R.fmap(r => #PointSet(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mixture = (
|
||||||
|
scaleMultiply: (genericDist, float) => result<genericDist, error>,
|
||||||
|
pointwiseAdd: (genericDist, genericDist) => result<genericDist, error>,
|
||||||
|
values: array<(genericDist, float)>,
|
||||||
|
) => {
|
||||||
|
let properlyWeightedValues =
|
||||||
|
values |> E.A.fmap(((dist, weight)) => scaleMultiply(dist, weight)) |> E.A.R.firstErrorOrOpen
|
||||||
|
properlyWeightedValues |> E.R.bind(_, values => {
|
||||||
|
values
|
||||||
|
|> Js.Array.sliceFrom(1)
|
||||||
|
|> E.A.fold_left(
|
||||||
|
(acc, x) => E.R.bind(acc, acc => pointwiseAdd(acc, x)),
|
||||||
|
Ok(E.A.unsafe_get(values, 0)),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -29,11 +29,20 @@ let fromResult = (r: result<outputType, error>): outputType =>
|
||||||
| Error(e) => #Error(e)
|
| Error(e) => #Error(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let outputToDistResult = (b: outputType): result<genericDist, error> =>
|
||||||
|
switch b {
|
||||||
|
| #Dist(r) => Ok(r)
|
||||||
|
| #Error(r) => Error(r)
|
||||||
|
| _ => Error(ImpossiblePath)
|
||||||
|
}
|
||||||
|
|
||||||
let rec run = (extra, fnName: operation): outputType => {
|
let rec run = (extra, fnName: operation): outputType => {
|
||||||
let {sampleCount, xyPointLength} = extra
|
let {sampleCount, xyPointLength} = extra
|
||||||
|
|
||||||
let reCall = (~extra=extra, ~fnName=fnName, ()) => {
|
let reCall = (~extra=extra, ~fnName=fnName, ()) => {
|
||||||
run(extra, fnName)
|
run(extra, fnName)
|
||||||
}
|
}
|
||||||
|
|
||||||
let toPointSet = r => {
|
let toPointSet = r => {
|
||||||
switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) {
|
switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) {
|
||||||
| #Dist(#PointSet(p)) => Ok(p)
|
| #Dist(#PointSet(p)) => Ok(p)
|
||||||
|
@ -41,6 +50,7 @@ let rec run = (extra, fnName: operation): outputType => {
|
||||||
| _ => Error(ImpossiblePath)
|
| _ => Error(ImpossiblePath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let toSampleSet = r => {
|
let toSampleSet = r => {
|
||||||
switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
|
switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
|
||||||
| #Dist(#SampleSet(p)) => Ok(p)
|
| #Dist(#SampleSet(p)) => Ok(p)
|
||||||
|
@ -49,6 +59,18 @@ let rec run = (extra, fnName: operation): outputType => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let scaleMultiply = (r, weight) =>
|
||||||
|
reCall(
|
||||||
|
~fnName=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r),
|
||||||
|
(),
|
||||||
|
) |> outputToDistResult
|
||||||
|
|
||||||
|
let pointwiseAdd = (r1, r2) =>
|
||||||
|
reCall(
|
||||||
|
~fnName=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1),
|
||||||
|
(),
|
||||||
|
) |> outputToDistResult
|
||||||
|
|
||||||
let fromDistFn = (subFn: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
|
let fromDistFn = (subFn: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
|
||||||
switch subFn {
|
switch subFn {
|
||||||
| #toFloat(fnName) =>
|
| #toFloat(fnName) =>
|
||||||
|
@ -89,10 +111,8 @@ let rec run = (extra, fnName: operation): outputType => {
|
||||||
|
|
||||||
switch fnName {
|
switch fnName {
|
||||||
| #fromDist(subFn, dist) => fromDistFn(subFn, dist)
|
| #fromDist(subFn, dist) => fromDistFn(subFn, dist)
|
||||||
| #fromFloat(subFn, float) => reCall(
|
| #fromFloat(subFn, float) => reCall(~fnName=#fromDist(subFn, GenericDist.fromFloat(float)), ())
|
||||||
~fnName=#fromDist(subFn, #Symbolic(SymbolicDist.Float.make(float))),
|
| #mixture(dists) =>
|
||||||
(),
|
GenericDist.mixture(scaleMultiply, pointwiseAdd, dists) |> E.R.fmap(r => #Dist(r)) |> fromResult
|
||||||
)
|
|
||||||
| _ => #Error(NotYetImplemented)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user