Added mixture fn for generic distributions

This commit is contained in:
Ozzie Gooen 2022-03-27 17:37:27 -04:00
parent b70e8e02e1
commit 3f678e24a1
2 changed files with 47 additions and 9 deletions

View File

@ -5,13 +5,15 @@ type toPointSetFn = genericDist => result<PointSetTypes.pointSetDist, error>
type toSampleSetFn = genericDist => result<array<float>, error>
type t = genericDist
let sampleN = (n, t: t) =>
let sampleN = (n, t: t) =>
switch t {
| #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
| #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
| #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
}
let fromFloat = (f: float) => #Symbolic(SymbolicDist.Float.make(f))
let toString = (t: t) =>
switch t {
| #PointSet(_) => "Point Set Distribution"
@ -19,15 +21,14 @@ let toString = (t: t) =>
| #SampleSet(_) => "Sample Set Distribution"
}
let normalize = (t: t) =>
let normalize = (t: t) =>
switch t {
| #PointSet(r) => #PointSet(PointSetDist.T.normalize(r))
| #Symbolic(_) => t
| #SampleSet(_) => t
}
// let isNormalized = (t:t) =>
// let isNormalized = (t:t) =>
let operationToFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => {
let symbolicSolution = switch t {
@ -214,3 +215,20 @@ let pointwiseCombinationFloat = (
})
} |> E.R.fmap(r => #PointSet(r))
}
let mixture = (
scaleMultiply: (genericDist, float) => result<genericDist, error>,
pointwiseAdd: (genericDist, genericDist) => result<genericDist, error>,
values: array<(genericDist, float)>,
) => {
let properlyWeightedValues =
values |> E.A.fmap(((dist, weight)) => scaleMultiply(dist, weight)) |> E.A.R.firstErrorOrOpen
properlyWeightedValues |> E.R.bind(_, values => {
values
|> Js.Array.sliceFrom(1)
|> E.A.fold_left(
(acc, x) => E.R.bind(acc, acc => pointwiseAdd(acc, x)),
Ok(E.A.unsafe_get(values, 0)),
)
})
}

View File

@ -29,11 +29,20 @@ let fromResult = (r: result<outputType, error>): outputType =>
| Error(e) => #Error(e)
}
let outputToDistResult = (b: outputType): result<genericDist, error> =>
switch b {
| #Dist(r) => Ok(r)
| #Error(r) => Error(r)
| _ => Error(ImpossiblePath)
}
let rec run = (extra, fnName: operation): outputType => {
let {sampleCount, xyPointLength} = extra
let reCall = (~extra=extra, ~fnName=fnName, ()) => {
run(extra, fnName)
}
let toPointSet = r => {
switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) {
| #Dist(#PointSet(p)) => Ok(p)
@ -41,6 +50,7 @@ let rec run = (extra, fnName: operation): outputType => {
| _ => Error(ImpossiblePath)
}
}
let toSampleSet = r => {
switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
| #Dist(#SampleSet(p)) => Ok(p)
@ -49,6 +59,18 @@ let rec run = (extra, fnName: operation): outputType => {
}
}
let scaleMultiply = (r, weight) =>
reCall(
~fnName=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r),
(),
) |> outputToDistResult
let pointwiseAdd = (r1, r2) =>
reCall(
~fnName=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1),
(),
) |> outputToDistResult
let fromDistFn = (subFn: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
switch subFn {
| #toFloat(fnName) =>
@ -89,10 +111,8 @@ let rec run = (extra, fnName: operation): outputType => {
switch fnName {
| #fromDist(subFn, dist) => fromDistFn(subFn, dist)
| #fromFloat(subFn, float) => reCall(
~fnName=#fromDist(subFn, #Symbolic(SymbolicDist.Float.make(float))),
(),
)
| _ => #Error(NotYetImplemented)
| #fromFloat(subFn, float) => reCall(~fnName=#fromDist(subFn, GenericDist.fromFloat(float)), ())
| #mixture(dists) =>
GenericDist.mixture(scaleMultiply, pointwiseAdd, dists) |> E.R.fmap(r => #Dist(r)) |> fromResult
}
}