squiggle/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res

268 lines
8.1 KiB
Plaintext
Raw Normal View History

2022-01-29 22:43:08 +00:00
open Distributions
// The distribution type this module operates on. The switches in this file handle
// its three variants: Mixed, Discrete and Continuous.
type t = PointSetTypes.pointSetDist
2022-02-17 14:50:43 +00:00
2022-01-29 22:43:08 +00:00
// Eliminates a point set distribution: runs the handler that matches the variant
// of `t` on that variant's payload and returns the handler's result.
// The handlers are passed as a tuple ordered (mixed, discrete, continuous).
let mapToAll = ((onMixed, onDiscrete, onContinuous), t: t) =>
  switch t {
  | Continuous(shape) => onContinuous(shape)
  | Discrete(shape) => onDiscrete(shape)
  | Mixed(shape) => onMixed(shape)
  }
// Transforms the payload of `t` with the handler for its variant and re-wraps the
// result in the same variant, so a Mixed stays Mixed, a Discrete stays Discrete, etc.
// The handlers are passed as a tuple ordered (mixed, discrete, continuous).
let fmap = ((onMixed, onDiscrete, onContinuous), t: t): t =>
  switch t {
  | Continuous(shape) => Continuous(onContinuous(shape))
  | Discrete(shape) => Discrete(onDiscrete(shape))
  | Mixed(shape) => Mixed(onMixed(shape))
  }
// Like `fmap`, but for handlers that return a result: the handler's result is passed
// through `E.R2.fmap` with a function that re-wraps the value in the variant `t` had.
// The handlers are passed as a tuple ordered (mixed, discrete, continuous).
let fmapResult = ((onMixed, onDiscrete, onContinuous), t: t): result<t, 'e> =>
  switch t {
  | Continuous(shape) => E.R2.fmap(onContinuous(shape), x => PointSetTypes.Continuous(x))
  | Discrete(shape) => E.R2.fmap(onDiscrete(shape), x => PointSetTypes.Discrete(x))
  | Mixed(shape) => E.R2.fmap(onMixed(shape), x => PointSetTypes.Mixed(x))
  }
2022-01-29 22:43:08 +00:00
// Converts any point set distribution into a mixed shape.
// - Mixed: returned unchanged.
// - Discrete: becomes the discrete part, paired with `Continuous.empty`.
// - Continuous: becomes the continuous part, paired with `Discrete.empty`.
// In the last two cases the integral caches of the source shape are carried over.
let toMixed = (t: t) =>
  switch t {
  | Mixed(mixed) => mixed
  | Discrete(discrete) =>
    Mixed.make(
      ~integralSumCache=discrete.integralSumCache,
      ~integralCache=discrete.integralCache,
      ~discrete,
      ~continuous=Continuous.empty,
    )
  | Continuous(continuous) =>
    Mixed.make(
      ~integralSumCache=continuous.integralSumCache,
      ~integralCache=continuous.integralCache,
      ~discrete=Discrete.empty,
      ~continuous,
    )
  }
2022-03-28 11:56:20 +00:00
// TODO WARNING: combineAlgebraicallyWithDiscrete will break for subtraction and division, e.g. discrete - continuous
// Combines two point set distributions with the convolution operation `op`.
// Same-kind pairs use the Continuous or Discrete implementation directly; a
// discrete/continuous pair goes through `Continuous.combineAlgebraicallyWithDiscrete`,
// with `~discretePosition` recording which operand was the discrete one. Any pair
// involving a Mixed operand converts both sides with `toMixed` first.
let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t =>
  switch (t1, t2) {
  | (Continuous(left), Continuous(right)) =>
    Continuous.T.toPointSetDist(Continuous.combineAlgebraically(op, left, right))
  | (Discrete(discrete), Continuous(continuous)) =>
    // The continuous shape is always passed before the discrete one.
    Continuous.T.toPointSetDist(
      Continuous.combineAlgebraicallyWithDiscrete(
        op,
        continuous,
        discrete,
        ~discretePosition=First,
      ),
    )
  | (Continuous(continuous), Discrete(discrete)) =>
    Continuous.T.toPointSetDist(
      Continuous.combineAlgebraicallyWithDiscrete(
        op,
        continuous,
        discrete,
        ~discretePosition=Second,
      ),
    )
  | (Discrete(left), Discrete(right)) =>
    Discrete.T.toPointSetDist(Discrete.combineAlgebraically(op, left, right))
  | (left, right) =>
    Mixed.T.toPointSetDist(Mixed.combineAlgebraically(op, toMixed(left), toMixed(right)))
  }
// Combines two point set distributions pointwise with `fn`, which may fail.
// Continuous/Continuous and Discrete/Discrete pairs use the matching implementation;
// every other pair is converted with `toMixed` and combined as Mixed.
// - integralSumCachesFn: optional combiner for the integral sum caches (defaults to None).
// - integralCachesFn: optional combiner for the integral caches (defaults to None);
//   it is only forwarded on the Mixed path.
// Returns the combined distribution wrapped in the variant of the path taken, or the
// error reported by the underlying implementation.
let combinePointwise = (
  ~integralSumCachesFn: (float, float) => option<float>=(_, _) => None,
  ~integralCachesFn: (
    PointSetTypes.continuousShape,
    PointSetTypes.continuousShape,
  ) => option<PointSetTypes.continuousShape>=(_, _) => None,
  fn: (float, float) => result<float, Operation.Error.t>,
  t1: t,
  t2: t,
): result<PointSetTypes.pointSetDist, Operation.Error.t> =>
  switch (t1, t2) {
  | (Continuous(left), Continuous(right)) =>
    E.R2.fmap(Continuous.combinePointwise(~integralSumCachesFn, fn, left, right), shape =>
      PointSetTypes.Continuous(shape)
    )
  | (Discrete(left), Discrete(right)) =>
    // The discrete implementation takes `fn` as a labeled argument.
    E.R2.fmap(Discrete.combinePointwise(~integralSumCachesFn, ~fn, left, right), shape =>
      PointSetTypes.Discrete(shape)
    )
  | (left, right) =>
    E.R2.fmap(
      Mixed.combinePointwise(
        ~integralSumCachesFn,
        ~integralCachesFn,
        fn,
        toMixed(left),
        toMixed(right),
      ),
      shape => PointSetTypes.Mixed(shape),
    )
  }
// `Dist` instance for point set distributions. Almost every member forwards to the
// Mixed, Discrete or Continuous implementation that matches the variant of the value:
// `mapToAll` is used when the result is a plain value, `fmap`/`fmapResult` when the
// result must be wrapped back into the same variant.
module T = Dist({
  type t = PointSetTypes.pointSetDist
  type integral = PointSetTypes.continuousShape

  // Value of the distribution at x = `f`, taken from the variant's own xToY.
  let xToY = (f: float) =>
    mapToAll((Mixed.T.xToY(f), Discrete.T.xToY(f), Continuous.T.xToY(f)))

  // Identity: the value already is a point set distribution.
  let toPointSetDist = (t: t) => t

  // Downsamples with parameter `i`, keeping the variant.
  let downsample = (i, t) => {
    let perVariant = (Mixed.T.downsample(i), Discrete.T.downsample(i), Continuous.T.downsample(i))
    fmap(perVariant, t)
  }

  // Truncates with the two cutoffs, keeping the variant.
  let truncate = (leftCutoff, rightCutoff, t): t => {
    let perVariant = (
      Mixed.T.truncate(leftCutoff, rightCutoff),
      Discrete.T.truncate(leftCutoff, rightCutoff),
      Continuous.T.truncate(leftCutoff, rightCutoff),
    )
    fmap(perVariant, t)
  }

  // Normalizes, keeping the variant.
  let normalize = fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize))

  // Replaces the integral cache of the underlying shape, keeping the variant.
  let updateIntegralCache = (integralCache, t: t): t => {
    let perVariant = (
      Mixed.T.updateIntegralCache(integralCache),
      Discrete.T.updateIntegralCache(integralCache),
      Continuous.T.updateIntegralCache(integralCache),
    )
    fmap(perVariant, t)
  }

  // Conversions and simple queries, each delegated to the variant's implementation.
  let toContinuous = mapToAll((
    Mixed.T.toContinuous,
    Discrete.T.toContinuous,
    Continuous.T.toContinuous,
  ))
  let toDiscrete = mapToAll((Mixed.T.toDiscrete, Discrete.T.toDiscrete, Continuous.T.toDiscrete))
  let toDiscreteProbabilityMassFraction = mapToAll((
    Mixed.T.toDiscreteProbabilityMassFraction,
    Discrete.T.toDiscreteProbabilityMassFraction,
    Continuous.T.toDiscreteProbabilityMassFraction,
  ))
  let minX = mapToAll((Mixed.T.minX, Discrete.T.minX, Continuous.T.minX))

  // Integral accessors, each delegated to the variant's `Integral` submodule.
  let integral = mapToAll((
    Mixed.T.Integral.get,
    Discrete.T.Integral.get,
    Continuous.T.Integral.get,
  ))
  let integralEndY = mapToAll((
    Mixed.T.Integral.sum,
    Discrete.T.Integral.sum,
    Continuous.T.Integral.sum,
  ))
  let integralXtoY = f =>
    mapToAll((
      Mixed.T.Integral.xToY(f),
      Discrete.T.Integral.xToY(f),
      Continuous.T.Integral.xToY(f),
    ))
  let integralYtoX = f =>
    mapToAll((
      Mixed.T.Integral.yToX(f),
      Discrete.T.Integral.yToX(f),
      Continuous.T.Integral.yToX(f),
    ))
  let maxX = mapToAll((Mixed.T.maxX, Discrete.T.maxX, Continuous.T.maxX))

  // Builds a `t => t` that applies `fn` through the variant's mapY.
  // Both cache functions default to returning None.
  let mapY = (~integralSumCacheFn=_ => None, ~integralCacheFn=_ => None, ~fn: float => float): (
    t => t
  ) =>
    fmap((
      Mixed.T.mapY(~integralSumCacheFn, ~integralCacheFn, ~fn),
      Discrete.T.mapY(~integralSumCacheFn, ~integralCacheFn, ~fn),
      Continuous.T.mapY(~integralSumCacheFn, ~integralCacheFn, ~fn),
    ))

  // Same as `mapY` for an `fn` that can fail; the error is propagated in the result.
  let mapYResult = (
    ~integralSumCacheFn=_ => None,
    ~integralCacheFn=_ => None,
    ~fn: float => result<float, 'e>,
  ): (t => result<t, 'e>) =>
    fmapResult((
      Mixed.T.mapYResult(~integralSumCacheFn, ~integralCacheFn, ~fn),
      Discrete.T.mapYResult(~integralSumCacheFn, ~integralCacheFn, ~fn),
      Continuous.T.mapYResult(~integralSumCacheFn, ~integralCacheFn, ~fn),
    ))

  // Mean and variance, delegated to the variant's implementation.
  let mean = mapToAll((Mixed.T.mean, Discrete.T.mean, Continuous.T.mean))
  let variance = mapToAll((Mixed.T.variance, Discrete.T.variance, Continuous.T.variance))

  // let klDivergence = (prediction: t, answer: t) =>
  // switch (prediction, answer) {
  // | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
  // | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
  // | (m1, m2) => Mixed.T.klDivergence(m1->toMixed, m2->toMixed)
  // }
  //
  // let logScoreWithPointResolution = (~prediction: t, ~answer: float, ~prior: option<t>) => {
  // switch (prior, prediction) {
  // | (Some(Continuous(t1)), Continuous(t2)) =>
  // Continuous.T.logScoreWithPointResolution(~prediction=t2, ~answer, ~prior=t1->Some)
  // | (None, Continuous(t2)) =>
  // Continuous.T.logScoreWithPointResolution(~prediction=t2, ~answer, ~prior=None)
  // | _ => Error(Operation.NotYetImplemented)
  // }
  // }
})
// Computes the log score for `args` by delegating to `PointSetDist_Scoring.logScore`,
// supplying this module's `combinePointwise` as the combiner and `T.integralEndY`
// as the integrator.
let logScore = (args: PointSetDist_Scoring.scoreArgs): result<float, Operation.Error.t> =>
  args->PointSetDist_Scoring.logScore(~integrateFn=T.integralEndY, ~combineFn=combinePointwise)
2022-01-29 22:43:08 +00:00
// Density of `t` at x = `f`: the sum of the continuous and discrete components of
// the mixed point returned by `T.xToY`.
let pdf = (f: float, t: t) => {
  let point: PointSetTypes.mixedPoint = t |> T.xToY(f)
  point.continuous +. point.discrete
}
// Alias of `T.Integral.yToX`; `operate` uses it for the #Inv operation.
let inv = T.Integral.yToX
// Alias of `T.Integral.xToY`; `operate` uses it for the #Cdf operation.
let cdf = T.Integral.xToY
// Calls `fn` exactly `n` times, in order, and returns the results as an array of
// length `n`. A non-positive `n` yields an empty array without calling `fn`.
// Built with `Belt.Array.makeBy` rather than pre-filling an array with a placeholder
// and overwriting each slot, so no dummy value is needed and the element type is
// whatever `fn` returns.
let doN = (n, fn) => Belt.Array.makeBy(n, _ => fn())
// Draws one sample from `t`: picks a random float with `Random.float(1.)` and maps
// it through `T.Integral.yToX`.
let sample = (t: t): float => {
  let uniformDraw = Random.float(1.)
  T.Integral.yToX(uniformDraw, t)
}
// True only for a Discrete distribution whose shape holds exactly one x and exactly
// one y, with that y equal to 1.0. Everything else is false.
let isFloat = (t: t) =>
  switch t {
  | Discrete({xyShape: {xs: [_], ys: [mass]}}) => mass == 1.0
  | _ => false
  }
// Draws `n` samples from `dist`. The integral is fetched once and stored back into
// the distribution's integral cache before sampling, and every draw uses that
// updated copy.
let sampleNRendered = (n, dist) => {
  let cachedDist = dist |> T.updateIntegralCache(Some(T.Integral.get(dist)))
  doN(n, () => sample(cachedDist))
}
// Evaluates a distribution-to-float operation on `s` by dispatching to the matching
// helper: `pdf`, `cdf`, `inv`, `sample` or `T.mean`.
let operate = (distToFloatOp: Operation.distToFloatOperation, s): float =>
  switch distToFloatOp {
  | #Mean => T.mean(s)
  | #Sample => sample(s)
  | #Inv(y) => inv(y, s)
  | #Cdf(x) => cdf(x, s)
  | #Pdf(x) => pdf(x, s)
  }
2022-04-23 13:56:47 +00:00
// Renders `t` as a sparkline string.
// The distribution is converted with `T.toContinuous`, downsampled with
// `Continuous.downsampleEquallyOverX(bucketCount)`, and the ys of the resulting shape
// are handed to `Sparklines.create`. When `T.toContinuous` yields no value the result
// is `Error(PointSetTypes.CannotSparklineDiscrete)`.
let toSparkline = (t: t, bucketCount): result<string, PointSetTypes.sparklineError> => {
  let downsampled =
    T.toContinuous(t)->E.O2.fmap(Continuous.downsampleEquallyOverX(bucketCount))
  let continuousOrError = downsampled->E.O2.toResult(PointSetTypes.CannotSparklineDiscrete)
  continuousOrError->E.R2.fmap(r => Continuous.getShape(r).ys->Sparklines.create())
}