// squiggle/packages/squiggle-lang/src/rescript/pointSet/Shape.res
open Distributions
type t = PointSetTypes.shape

// Dispatches on the shape's variant and applies the matching function to its
// payload: `fnMixed` for Mixed, `fnDiscrete` for Discrete, `fnContinuous` for
// Continuous. All three must return the same type; the result is returned
// as-is (it is not re-wrapped in a shape).
let mapToAll = ((fnMixed, fnDiscrete, fnContinuous), t: t) =>
  switch t {
  | Discrete(discrete) => fnDiscrete(discrete)
  | Continuous(continuous) => fnContinuous(continuous)
  | Mixed(mixed) => fnMixed(mixed)
  }
// Like `mapToAll`, but the result of each function is re-wrapped in the same
// constructor it came from, so the shape's variant is preserved.
let fmap = ((fnMixed, fnDiscrete, fnContinuous), t: t): t =>
  switch t {
  | Discrete(discrete) => Discrete(fnDiscrete(discrete))
  | Continuous(continuous) => Continuous(fnContinuous(continuous))
  | Mixed(mixed) => Mixed(fnMixed(mixed))
  }
// Converts any shape into a mixed shape.
// - A mixed shape is returned unchanged.
// - A discrete or continuous shape becomes the matching component of a new
//   mixed shape. The other component is empty, and the input's
//   `integralSumCache` and `integralCache` are carried over.
let toMixed = (t: t) =>
  switch t {
  | Mixed(mixed) => mixed
  | Discrete(discrete) =>
    Mixed.make(
      ~integralSumCache=discrete.integralSumCache,
      ~integralCache=discrete.integralCache,
      ~discrete,
      ~continuous=Continuous.empty,
    )
  | Continuous(continuous) =>
    Mixed.make(
      ~integralSumCache=continuous.integralSumCache,
      ~integralCache=continuous.integralCache,
      ~discrete=Discrete.empty,
      ~continuous,
    )
  }
// Combines two shapes algebraically under `op`, choosing an implementation by
// the pair of kinds:
//  - continuous with continuous  -> Continuous.combineAlgebraically
//  - continuous with discrete    -> Continuous.combineAlgebraicallyWithDiscrete
//  - discrete with discrete      -> Discrete.combineAlgebraically
//  - any pair involving a mixed  -> both converted with `toMixed`, then
//                                   Mixed.combineAlgebraically
// NOTE(review): in the continuous/discrete case the continuous operand is always
// passed first, whichever side it came from. If `op` can be non-commutative
// (e.g. a subtraction or division), `t1 = Discrete`, `t2 = Continuous`
// presumably computes the reversed operation. TODO: confirm against
// Continuous.combineAlgebraicallyWithDiscrete.
let combineAlgebraically = (op: ExpressionTypes.algebraicOperation, t1: t, t2: t): t =>
  switch (t1, t2) {
  | (Continuous(c1), Continuous(c2)) =>
    Continuous.T.toShape(Continuous.combineAlgebraically(op, c1, c2))
  | (Continuous(c), Discrete(d))
  | (Discrete(d), Continuous(c)) =>
    Continuous.T.toShape(Continuous.combineAlgebraicallyWithDiscrete(op, c, d))
  | (Discrete(d1), Discrete(d2)) => Discrete.T.toShape(Discrete.combineAlgebraically(op, d1, d2))
  | (_, _) => Mixed.T.toShape(Mixed.combineAlgebraically(op, toMixed(t1), toMixed(t2)))
  }
// Combines two shapes point by point with `fn`.
// - Two continuous shapes are combined by Continuous.combinePointwise.
// - Two discrete shapes are combined by Discrete.combinePointwise.
// - Any other pairing is converted with `toMixed` and combined by
//   Mixed.combinePointwise.
// `integralSumCachesFn` and `integralCachesFn` are forwarded unchanged to
// whichever of those is used. Both default to always returning `None`.
let combinePointwise = (
  ~integralSumCachesFn: (float, float) => option<float>=(_, _) => None,
  ~integralCachesFn: (
    PointSetTypes.continuousShape,
    PointSetTypes.continuousShape,
  ) => option<PointSetTypes.continuousShape>=(_, _) => None,
  fn,
  t1: t,
  t2: t,
) =>
  switch (t1, t2) {
  | (Continuous(c1), Continuous(c2)) => {
      let combined = Continuous.combinePointwise(
        ~integralSumCachesFn,
        ~integralCachesFn,
        fn,
        c1,
        c2,
      )
      PointSetTypes.Continuous(combined)
    }
  | (Discrete(d1), Discrete(d2)) => {
      let combined = Discrete.combinePointwise(~integralSumCachesFn, ~integralCachesFn, fn, d1, d2)
      PointSetTypes.Discrete(combined)
    }
  | (_, _) => {
      let mixed1 = toMixed(t1)
      let mixed2 = toMixed(t2)
      let combined = Mixed.combinePointwise(
        ~integralSumCachesFn,
        ~integralCachesFn,
        fn,
        mixed1,
        mixed2,
      )
      PointSetTypes.Mixed(combined)
    }
  }
// Instantiates the generic `Dist` interface for `PointSetTypes.shape`.
// Every operation dispatches on the shape's variant and delegates to the
// matching Mixed / Discrete / Continuous implementation:
// - operations that return a shape go through `fmap`, so the variant is preserved;
// - operations that return a plain value go through `mapToAll`.
module T = Dist({
  type t = PointSetTypes.shape
  type integral = PointSetTypes.continuousShape

  let xToY = (f: float) => mapToAll((Mixed.T.xToY(f), Discrete.T.xToY(f), Continuous.T.xToY(f)))

  let toShape = (t: t) => t

  let downsample = (i, t) =>
    fmap((Mixed.T.downsample(i), Discrete.T.downsample(i), Continuous.T.downsample(i)), t)

  let truncate = (leftCutoff, rightCutoff, t): t =>
    fmap(
      (
        Mixed.T.truncate(leftCutoff, rightCutoff),
        Discrete.T.truncate(leftCutoff, rightCutoff),
        Continuous.T.truncate(leftCutoff, rightCutoff),
      ),
      t,
    )

  let normalize = fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize))

  let updateIntegralCache = (integralCache, t: t): t =>
    fmap(
      (
        Mixed.T.updateIntegralCache(integralCache),
        Discrete.T.updateIntegralCache(integralCache),
        Continuous.T.updateIntegralCache(integralCache),
      ),
      t,
    )

  // toContinuous / toDiscrete / toDiscreteProbabilityMassFraction are each
  // defined exactly once, here. Earlier `t => None` / `t => 0.0` stubs were dead:
  // these later bindings shadowed them.
  let toContinuous = mapToAll((
    Mixed.T.toContinuous,
    Discrete.T.toContinuous,
    Continuous.T.toContinuous,
  ))

  let toDiscrete = mapToAll((Mixed.T.toDiscrete, Discrete.T.toDiscrete, Continuous.T.toDiscrete))

  let toDiscreteProbabilityMassFraction = mapToAll((
    Mixed.T.toDiscreteProbabilityMassFraction,
    Discrete.T.toDiscreteProbabilityMassFraction,
    Continuous.T.toDiscreteProbabilityMassFraction,
  ))

  let minX = mapToAll((Mixed.T.minX, Discrete.T.minX, Continuous.T.minX))

  let integral = mapToAll((
    Mixed.T.Integral.get,
    Discrete.T.Integral.get,
    Continuous.T.Integral.get,
  ))

  let integralEndY = mapToAll((
    Mixed.T.Integral.sum,
    Discrete.T.Integral.sum,
    Continuous.T.Integral.sum,
  ))

  let integralXtoY = f =>
    mapToAll((Mixed.T.Integral.xToY(f), Discrete.T.Integral.xToY(f), Continuous.T.Integral.xToY(f)))

  let integralYtoX = f =>
    mapToAll((Mixed.T.Integral.yToX(f), Discrete.T.Integral.yToX(f), Continuous.T.Integral.yToX(f)))

  let maxX = mapToAll((Mixed.T.maxX, Discrete.T.maxX, Continuous.T.maxX))

  // Both cache functions default to ignoring their argument and returning
  // `None`. They are forwarded unchanged to each per-kind `mapY`.
  let mapY = (~integralSumCacheFn=_ => None, ~integralCacheFn=_ => None, ~fn) =>
    fmap((
      Mixed.T.mapY(~integralSumCacheFn, ~integralCacheFn, ~fn),
      Discrete.T.mapY(~integralSumCacheFn, ~integralCacheFn, ~fn),
      Continuous.T.mapY(~integralSumCacheFn, ~integralCacheFn, ~fn),
    ))

  let mean = (t: t): float => mapToAll((Mixed.T.mean, Discrete.T.mean, Continuous.T.mean), t)

  let variance = (t: t): float =>
    mapToAll((Mixed.T.variance, Discrete.T.variance, Continuous.T.variance), t)
})
// Density of `t` at `f`. `T.xToY` yields a mixed point, and the result is the
// sum of its continuous and discrete parts.
let pdf = (f: float, t: t) => {
  let point: PointSetTypes.mixedPoint = t |> T.xToY(f)
  point.continuous +. point.discrete
}
// Inverse of the cumulative integral: the x at which `T.Integral` of `t`
// reaches the value `f`.
let inv = (f, t) => T.Integral.yToX(f, t)

// Value of the cumulative integral of `t` at x = `f`.
let cdf = (f, t) => T.Integral.xToY(f, t)
// Calls `fn` `n` times, in order, and collects the results in an array.
// Returns an empty array when `n <= 0`.
let doN = (n, fn) => Belt.Array.makeBy(n, _index => fn())
// Draws one sample from `t`: a uniform draw from `Random.float(1.)` is mapped
// through `T.Integral.yToX`, the inverse of the cumulative integral.
let sample = (t: t): float => T.Integral.yToX(Random.float(1.), t)
// True only for a discrete shape holding a single point whose y value is
// exactly 1.0. Every other shape, including mixed and continuous ones, is false.
let isFloat = (t: t) =>
  switch t {
  | Discrete({xyShape: {xs: [_], ys: [1.0]}}) => true
  | Discrete(_)
  | Continuous(_)
  | Mixed(_) => false
  }
// Draws `n` samples from `dist`. The integral is computed once and stored via
// `T.updateIntegralCache` before sampling. Presumably this lets each `sample`
// call (which goes through `T.Integral.yToX`) reuse it instead of recomputing.
// TODO: confirm.
let sampleNRendered = (n, dist) => {
  let withCachedIntegral = T.updateIntegralCache(Some(T.Integral.get(dist)), dist)
  doN(n, () => sample(withCachedIntegral))
}
// Evaluates a distribution-to-float operation on the shape `s`:
//  - #Pdf(f): density at f (`pdf`)
//  - #Cdf(f): value of the cumulative integral at f (`cdf`)
//  - #Inv(f): x at which the cumulative integral reaches f (`inv`)
//  - #Sample: one random draw (`sample`)
//  - #Mean:   the mean (`T.mean`)
let operate = (distToFloatOp: ExpressionTypes.distToFloatOperation, s): float =>
  switch distToFloatOp {
  | #Pdf(f) => pdf(f, s)
  // Must dispatch to `cdf`, not `pdf`: dispatching to `pdf` would return the
  // density instead of the cumulative value.
  | #Cdf(f) => cdf(f, s)
  | #Inv(f) => inv(f, s)
  | #Sample => sample(s)
  | #Mean => T.mean(s)
  }