Added mixture fn for generic distributions

This commit is contained in:
Ozzie Gooen 2022-03-27 17:37:27 -04:00
parent b70e8e02e1
commit 3f678e24a1
2 changed files with 47 additions and 9 deletions

View File

@ -12,6 +12,8 @@ let sampleN = (n, t: t) =>
| #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented) | #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
} }
let fromFloat = (f: float) => #Symbolic(SymbolicDist.Float.make(f))
let toString = (t: t) => let toString = (t: t) =>
switch t { switch t {
| #PointSet(_) => "Point Set Distribution" | #PointSet(_) => "Point Set Distribution"
@ -28,7 +30,6 @@ let normalize = (t: t) =>
// let isNormalized = (t:t) => // let isNormalized = (t:t) =>
let operationToFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => { let operationToFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => {
let symbolicSolution = switch t { let symbolicSolution = switch t {
| #Symbolic(r) => | #Symbolic(r) =>
@ -214,3 +215,20 @@ let pointwiseCombinationFloat = (
}) })
} |> E.R.fmap(r => #PointSet(r)) } |> E.R.fmap(r => #PointSet(r))
} }
let mixture = (
scaleMultiply: (genericDist, float) => result<genericDist, error>,
pointwiseAdd: (genericDist, genericDist) => result<genericDist, error>,
values: array<(genericDist, float)>,
) => {
let properlyWeightedValues =
values |> E.A.fmap(((dist, weight)) => scaleMultiply(dist, weight)) |> E.A.R.firstErrorOrOpen
properlyWeightedValues |> E.R.bind(_, values => {
values
|> Js.Array.sliceFrom(1)
|> E.A.fold_left(
(acc, x) => E.R.bind(acc, acc => pointwiseAdd(acc, x)),
Ok(E.A.unsafe_get(values, 0)),
)
})
}

View File

@ -29,11 +29,20 @@ let fromResult = (r: result<outputType, error>): outputType =>
| Error(e) => #Error(e) | Error(e) => #Error(e)
} }
let outputToDistResult = (b: outputType): result<genericDist, error> =>
switch b {
| #Dist(r) => Ok(r)
| #Error(r) => Error(r)
| _ => Error(ImpossiblePath)
}
let rec run = (extra, fnName: operation): outputType => { let rec run = (extra, fnName: operation): outputType => {
let {sampleCount, xyPointLength} = extra let {sampleCount, xyPointLength} = extra
let reCall = (~extra=extra, ~fnName=fnName, ()) => { let reCall = (~extra=extra, ~fnName=fnName, ()) => {
run(extra, fnName) run(extra, fnName)
} }
let toPointSet = r => { let toPointSet = r => {
switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) { switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) {
| #Dist(#PointSet(p)) => Ok(p) | #Dist(#PointSet(p)) => Ok(p)
@ -41,6 +50,7 @@ let rec run = (extra, fnName: operation): outputType => {
| _ => Error(ImpossiblePath) | _ => Error(ImpossiblePath)
} }
} }
let toSampleSet = r => { let toSampleSet = r => {
switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) { switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
| #Dist(#SampleSet(p)) => Ok(p) | #Dist(#SampleSet(p)) => Ok(p)
@ -49,6 +59,18 @@ let rec run = (extra, fnName: operation): outputType => {
} }
} }
let scaleMultiply = (r, weight) =>
reCall(
~fnName=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r),
(),
) |> outputToDistResult
let pointwiseAdd = (r1, r2) =>
reCall(
~fnName=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1),
(),
) |> outputToDistResult
let fromDistFn = (subFn: GenericDist_Types.Operation.fromDist, dist: genericDist) => let fromDistFn = (subFn: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
switch subFn { switch subFn {
| #toFloat(fnName) => | #toFloat(fnName) =>
@ -89,10 +111,8 @@ let rec run = (extra, fnName: operation): outputType => {
switch fnName { switch fnName {
| #fromDist(subFn, dist) => fromDistFn(subFn, dist) | #fromDist(subFn, dist) => fromDistFn(subFn, dist)
| #fromFloat(subFn, float) => reCall( | #fromFloat(subFn, float) => reCall(~fnName=#fromDist(subFn, GenericDist.fromFloat(float)), ())
~fnName=#fromDist(subFn, #Symbolic(SymbolicDist.Float.make(float))), | #mixture(dists) =>
(), GenericDist.mixture(scaleMultiply, pointwiseAdd, dists) |> E.R.fmap(r => #Dist(r)) |> fromResult
)
| _ => #Error(NotYetImplemented)
} }
} }