Added mixture fn for generic distributions
This commit is contained in:
parent
b70e8e02e1
commit
3f678e24a1
|
@ -5,13 +5,15 @@ type toPointSetFn = genericDist => result<PointSetTypes.pointSetDist, error>
|
|||
type toSampleSetFn = genericDist => result<array<float>, error>
|
||||
type t = genericDist
|
||||
|
||||
let sampleN = (n, t: t) =>
|
||||
let sampleN = (n, t: t) =>
|
||||
switch t {
|
||||
| #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
||||
| #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
|
||||
| #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
|
||||
}
|
||||
|
||||
let fromFloat = (f: float) => #Symbolic(SymbolicDist.Float.make(f))
|
||||
|
||||
let toString = (t: t) =>
|
||||
switch t {
|
||||
| #PointSet(_) => "Point Set Distribution"
|
||||
|
@ -19,15 +21,14 @@ let toString = (t: t) =>
|
|||
| #SampleSet(_) => "Sample Set Distribution"
|
||||
}
|
||||
|
||||
let normalize = (t: t) =>
|
||||
let normalize = (t: t) =>
|
||||
switch t {
|
||||
| #PointSet(r) => #PointSet(PointSetDist.T.normalize(r))
|
||||
| #Symbolic(_) => t
|
||||
| #SampleSet(_) => t
|
||||
}
|
||||
|
||||
// let isNormalized = (t:t) =>
|
||||
|
||||
// let isNormalized = (t:t) =>
|
||||
|
||||
let operationToFloat = (toPointSet: toPointSetFn, fnName, t: genericDist): result<float, error> => {
|
||||
let symbolicSolution = switch t {
|
||||
|
@ -214,3 +215,20 @@ let pointwiseCombinationFloat = (
|
|||
})
|
||||
} |> E.R.fmap(r => #PointSet(r))
|
||||
}
|
||||
|
||||
let mixture = (
|
||||
scaleMultiply: (genericDist, float) => result<genericDist, error>,
|
||||
pointwiseAdd: (genericDist, genericDist) => result<genericDist, error>,
|
||||
values: array<(genericDist, float)>,
|
||||
) => {
|
||||
let properlyWeightedValues =
|
||||
values |> E.A.fmap(((dist, weight)) => scaleMultiply(dist, weight)) |> E.A.R.firstErrorOrOpen
|
||||
properlyWeightedValues |> E.R.bind(_, values => {
|
||||
values
|
||||
|> Js.Array.sliceFrom(1)
|
||||
|> E.A.fold_left(
|
||||
(acc, x) => E.R.bind(acc, acc => pointwiseAdd(acc, x)),
|
||||
Ok(E.A.unsafe_get(values, 0)),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -29,11 +29,20 @@ let fromResult = (r: result<outputType, error>): outputType =>
|
|||
| Error(e) => #Error(e)
|
||||
}
|
||||
|
||||
let outputToDistResult = (b: outputType): result<genericDist, error> =>
|
||||
switch b {
|
||||
| #Dist(r) => Ok(r)
|
||||
| #Error(r) => Error(r)
|
||||
| _ => Error(ImpossiblePath)
|
||||
}
|
||||
|
||||
let rec run = (extra, fnName: operation): outputType => {
|
||||
let {sampleCount, xyPointLength} = extra
|
||||
|
||||
let reCall = (~extra=extra, ~fnName=fnName, ()) => {
|
||||
run(extra, fnName)
|
||||
}
|
||||
|
||||
let toPointSet = r => {
|
||||
switch reCall(~fnName=#fromDist(#toDist(#toPointSet), r), ()) {
|
||||
| #Dist(#PointSet(p)) => Ok(p)
|
||||
|
@ -41,6 +50,7 @@ let rec run = (extra, fnName: operation): outputType => {
|
|||
| _ => Error(ImpossiblePath)
|
||||
}
|
||||
}
|
||||
|
||||
let toSampleSet = r => {
|
||||
switch reCall(~fnName=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
|
||||
| #Dist(#SampleSet(p)) => Ok(p)
|
||||
|
@ -49,6 +59,18 @@ let rec run = (extra, fnName: operation): outputType => {
|
|||
}
|
||||
}
|
||||
|
||||
let scaleMultiply = (r, weight) =>
|
||||
reCall(
|
||||
~fnName=#fromDist(#toDistCombination(#Pointwise, #Multiply, #Float(weight)), r),
|
||||
(),
|
||||
) |> outputToDistResult
|
||||
|
||||
let pointwiseAdd = (r1, r2) =>
|
||||
reCall(
|
||||
~fnName=#fromDist(#toDistCombination(#Pointwise, #Add, #Dist(r2)), r1),
|
||||
(),
|
||||
) |> outputToDistResult
|
||||
|
||||
let fromDistFn = (subFn: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
|
||||
switch subFn {
|
||||
| #toFloat(fnName) =>
|
||||
|
@ -89,10 +111,8 @@ let rec run = (extra, fnName: operation): outputType => {
|
|||
|
||||
switch fnName {
|
||||
| #fromDist(subFn, dist) => fromDistFn(subFn, dist)
|
||||
| #fromFloat(subFn, float) => reCall(
|
||||
~fnName=#fromDist(subFn, #Symbolic(SymbolicDist.Float.make(float))),
|
||||
(),
|
||||
)
|
||||
| _ => #Error(NotYetImplemented)
|
||||
| #fromFloat(subFn, float) => reCall(~fnName=#fromDist(subFn, GenericDist.fromFloat(float)), ())
|
||||
| #mixture(dists) =>
|
||||
GenericDist.mixture(scaleMultiply, pointwiseAdd, dists) |> E.R.fmap(r => #Dist(r)) |> fromResult
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user