Merge pull request #69 from QURIresearch/dist-generic-library
Generic Distribution Library: First Pass
This commit is contained in:
commit
bc1a73b0af
|
@ -0,0 +1,76 @@
|
|||
open Jest
|
||||
open Expect
|
||||
|
||||
let env: GenericDist_GenericOperation.env = {
|
||||
sampleCount: 100,
|
||||
xyPointLength: 100,
|
||||
}
|
||||
|
||||
let normalDist: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 5.0, stdev: 2.0}))
|
||||
let normalDist10: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 10.0, stdev: 2.0}))
|
||||
let normalDist20: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 20.0, stdev: 2.0}))
|
||||
let uniformDist: GenericDist_Types.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
||||
|
||||
let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output)
|
||||
let {run} = module(GenericDist_GenericOperation)
|
||||
let {fmap} = module(GenericDist_GenericOperation.Output)
|
||||
let run = run(~env)
|
||||
let outputMap = fmap(~env)
|
||||
let toExt: option<'a> => 'a = E.O.toExt(
|
||||
"Should be impossible to reach (This error is in test file)",
|
||||
)
|
||||
|
||||
describe("normalize", () => {
|
||||
test("has no impact on normal dist", () => {
|
||||
let result = run(FromDist(ToDist(Normalize), normalDist))
|
||||
expect(result)->toEqual(Dist(normalDist))
|
||||
})
|
||||
})
|
||||
|
||||
describe("mean", () => {
|
||||
test("for a normal distribution", () => {
|
||||
let result = GenericDist_GenericOperation.run(~env, FromDist(ToFloat(#Mean), normalDist))
|
||||
expect(result)->toEqual(Float(5.0))
|
||||
})
|
||||
})
|
||||
|
||||
describe("mixture", () => {
|
||||
test("on two normal distributions", () => {
|
||||
let result =
|
||||
run(Mixture([(normalDist10, 0.5), (normalDist20, 0.5)]))
|
||||
->outputMap(FromDist(ToFloat(#Mean)))
|
||||
->toFloat
|
||||
->toExt
|
||||
expect(result)->toBeCloseTo(15.28)
|
||||
})
|
||||
})
|
||||
|
||||
describe("toPointSet", () => {
|
||||
test("on symbolic normal distribution", () => {
|
||||
let result =
|
||||
run(FromDist(ToDist(ToPointSet), normalDist))
|
||||
->outputMap(FromDist(ToFloat(#Mean)))
|
||||
->toFloat
|
||||
->toExt
|
||||
expect(result)->toBeCloseTo(5.09)
|
||||
})
|
||||
|
||||
test("on sample set distribution with under 4 points", () => {
|
||||
let result =
|
||||
run(FromDist(ToDist(ToPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
|
||||
FromDist(ToFloat(#Mean)),
|
||||
)
|
||||
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))
|
||||
})
|
||||
|
||||
Skip.test("on sample set", () => {
|
||||
let result =
|
||||
run(FromDist(ToDist(ToPointSet), normalDist))
|
||||
->outputMap(FromDist(ToDist(ToSampleSet(1000))))
|
||||
->outputMap(FromDist(ToDist(ToPointSet)))
|
||||
->outputMap(FromDist(ToFloat(#Mean)))
|
||||
->toFloat
|
||||
->toExt
|
||||
expect(result)->toBeCloseTo(5.09)
|
||||
})
|
||||
})
|
272
packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res
Normal file
272
packages/squiggle-lang/src/rescript/GenericDist/GenericDist.res
Normal file
|
@ -0,0 +1,272 @@
|
|||
//TODO: multimodal, add interface, test somehow, track performance, refactor sampleSet, refactor ASTEvaluator.res.
|
||||
type t = GenericDist_Types.genericDist
|
||||
type error = GenericDist_Types.error
|
||||
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
|
||||
type toSampleSetFn = t => result<array<float>, error>
|
||||
type scaleMultiplyFn = (t, float) => result<t, error>
|
||||
type pointwiseAddFn = (t, t) => result<t, error>
|
||||
|
||||
let sampleN = (t: t, n) =>
|
||||
switch t {
|
||||
| PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
||||
| Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
|
||||
| SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
|
||||
}
|
||||
|
||||
let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
|
||||
|
||||
let toString = (t: t) =>
|
||||
switch t {
|
||||
| PointSet(_) => "Point Set Distribution"
|
||||
| Symbolic(r) => SymbolicDist.T.toString(r)
|
||||
| SampleSet(_) => "Sample Set Distribution"
|
||||
}
|
||||
|
||||
let normalize = (t: t): t =>
|
||||
switch t {
|
||||
| PointSet(r) => PointSet(PointSetDist.T.normalize(r))
|
||||
| Symbolic(_) => t
|
||||
| SampleSet(_) => t
|
||||
}
|
||||
|
||||
let toFloatOperation = (
|
||||
t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~distToFloatOperation: Operation.distToFloatOperation,
|
||||
) => {
|
||||
let symbolicSolution = switch (t: t) {
|
||||
| Symbolic(r) =>
|
||||
switch SymbolicDist.T.operate(distToFloatOperation, r) {
|
||||
| Ok(f) => Some(f)
|
||||
| _ => None
|
||||
}
|
||||
| _ => None
|
||||
}
|
||||
|
||||
switch symbolicSolution {
|
||||
| Some(r) => Ok(r)
|
||||
| None => toPointSetFn(t)->E.R2.fmap(PointSetDist.operate(distToFloatOperation))
|
||||
}
|
||||
}
|
||||
|
||||
//Todo: If it's a pointSet, but the xyPointLenght is different from what it has, it should change.
|
||||
// This is tricky because the case of discrete distributions.
|
||||
// Also, change the outputXYPoints/pointSetDistLength details
|
||||
let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSetDist, error> => {
|
||||
switch (t: t) {
|
||||
| PointSet(pointSet) => Ok(pointSet)
|
||||
| Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
|
||||
| SampleSet(r) => {
|
||||
let response = SampleSet.toPointSetDist(
|
||||
~samples=r,
|
||||
~samplingInputs={
|
||||
sampleCount: sampleCount,
|
||||
outputXYPoints: xyPointLength,
|
||||
pointSetDistLength: xyPointLength,
|
||||
kernelWidth: None,
|
||||
},
|
||||
(),
|
||||
).pointSetDist
|
||||
switch response {
|
||||
| Some(r) => Ok(r)
|
||||
| None => Error(Other("Converting sampleSet to pointSet failed"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module Truncate = {
|
||||
let trySymbolicSimplification = (leftCutoff, rightCutoff, t: t): option<t> =>
|
||||
switch (leftCutoff, rightCutoff, t) {
|
||||
| (None, None, _) => None
|
||||
| (lc, rc, Symbolic(#Uniform(u))) if lc < rc =>
|
||||
Some(Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let run = (
|
||||
t: t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~leftCutoff=None: option<float>,
|
||||
~rightCutoff=None: option<float>,
|
||||
(),
|
||||
): result<t, error> => {
|
||||
let doesNotNeedCutoff = E.O.isNone(leftCutoff) && E.O.isNone(rightCutoff)
|
||||
if doesNotNeedCutoff {
|
||||
Ok(t)
|
||||
} else {
|
||||
switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
|
||||
| Some(r) => Ok(r)
|
||||
| None =>
|
||||
toPointSetFn(t)->E.R2.fmap(t => {
|
||||
GenericDist_Types.PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let truncate = Truncate.run
|
||||
|
||||
/* Given two random variables A and B, this returns the distribution
|
||||
of a new variable that is the result of the operation on A and B.
|
||||
For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2).
|
||||
In general, this is implemented via convolution.
|
||||
|
||||
TODO: It would be useful to be able to pass in a paramater to get this to run either with convolution or monte carlo.
|
||||
*/
|
||||
module AlgebraicCombination = {
|
||||
let tryAnalyticalSimplification = (
|
||||
arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
): option<result<SymbolicDistTypes.symbolicDist, string>> =>
|
||||
switch (arithmeticOperation, t1, t2) {
|
||||
| (arithmeticOperation, Symbolic(d1), Symbolic(d2)) =>
|
||||
switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) {
|
||||
| #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist))
|
||||
| #Error(er) => Some(Error(er))
|
||||
| #NoSolution => None
|
||||
}
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let runConvolution = (
|
||||
toPointSet: toPointSetFn,
|
||||
arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
) =>
|
||||
E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) =>
|
||||
PointSetDist.combineAlgebraically(arithmeticOperation, a, b)
|
||||
)
|
||||
|
||||
let runMonteCarlo = (
|
||||
toSampleSet: toSampleSetFn,
|
||||
arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
|
||||
t1: t,
|
||||
t2: t,
|
||||
) => {
|
||||
let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation)
|
||||
E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R2.fmap(((a, b)) => {
|
||||
Belt.Array.zip(a, b)->E.A2.fmap(((a, b)) => arithmeticOperation(a, b))
|
||||
})
|
||||
}
|
||||
|
||||
//I'm (Ozzie) really just guessing here, very little idea what's best
|
||||
let expectedConvolutionCost: t => int = x =>
|
||||
switch x {
|
||||
| Symbolic(#Float(_)) => 1
|
||||
| Symbolic(_) => 1000
|
||||
| PointSet(Discrete(m)) => m.xyShape->XYShape.T.length
|
||||
| PointSet(Mixed(_)) => 1000
|
||||
| PointSet(Continuous(_)) => 1000
|
||||
| _ => 1000
|
||||
}
|
||||
|
||||
let chooseConvolutionOrMonteCarlo = (t2: t, t1: t) =>
|
||||
expectedConvolutionCost(t1) * expectedConvolutionCost(t2) > 10000
|
||||
? #CalculateWithMonteCarlo
|
||||
: #CalculateWithConvolution
|
||||
|
||||
let run = (
|
||||
t1: t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~toSampleSetFn: toSampleSetFn,
|
||||
~arithmeticOperation,
|
||||
~t2: t,
|
||||
): result<t, error> => {
|
||||
switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
|
||||
| Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist))
|
||||
| Some(Error(e)) => Error(Other(e))
|
||||
| None =>
|
||||
switch chooseConvolutionOrMonteCarlo(t1, t2) {
|
||||
| #CalculateWithMonteCarlo =>
|
||||
runMonteCarlo(
|
||||
toSampleSetFn,
|
||||
arithmeticOperation,
|
||||
t1,
|
||||
t2,
|
||||
)->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
|
||||
| #CalculateWithConvolution =>
|
||||
runConvolution(
|
||||
toPointSetFn,
|
||||
arithmeticOperation,
|
||||
t1,
|
||||
t2,
|
||||
)->E.R2.fmap(r => GenericDist_Types.PointSet(r))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let algebraicCombination = AlgebraicCombination.run
|
||||
|
||||
//TODO: Add faster pointwiseCombine fn
|
||||
let pointwiseCombination = (
|
||||
t1: t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~arithmeticOperation,
|
||||
~t2: t,
|
||||
): result<t, error> => {
|
||||
E.R.merge(toPointSetFn(t1), toPointSetFn(t2))
|
||||
->E.R2.fmap(((t1, t2)) =>
|
||||
PointSetDist.combinePointwise(
|
||||
GenericDist_Types.Operation.arithmeticToFn(arithmeticOperation),
|
||||
t1,
|
||||
t2,
|
||||
)
|
||||
)
|
||||
->E.R2.fmap(r => GenericDist_Types.PointSet(r))
|
||||
}
|
||||
|
||||
let pointwiseCombinationFloat = (
|
||||
t: t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
|
||||
~float: float,
|
||||
): result<t, error> => {
|
||||
let m = switch arithmeticOperation {
|
||||
| #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid)
|
||||
| (#Multiply | #Divide | #Exponentiate | #Log) as arithmeticOperation =>
|
||||
toPointSetFn(t)->E.R2.fmap(t => {
|
||||
//TODO: Move to PointSet codebase
|
||||
let fn = (secondary, main) => Operation.Scale.toFn(arithmeticOperation, main, secondary)
|
||||
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithmeticOperation)
|
||||
let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithmeticOperation)
|
||||
PointSetDist.T.mapY(
|
||||
~integralSumCacheFn=integralSumCacheFn(float),
|
||||
~integralCacheFn=integralCacheFn(float),
|
||||
~fn=fn(float),
|
||||
t,
|
||||
)
|
||||
})
|
||||
}
|
||||
m->E.R2.fmap(r => GenericDist_Types.PointSet(r))
|
||||
}
|
||||
|
||||
//Note: The result should always cumulatively sum to 1. This would be good to test.
|
||||
//Note: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this.
|
||||
let mixture = (
|
||||
values: array<(t, float)>,
|
||||
~scaleMultiplyFn: scaleMultiplyFn,
|
||||
~pointwiseAddFn: pointwiseAddFn,
|
||||
) => {
|
||||
if E.A.length(values) == 0 {
|
||||
Error(GenericDist_Types.Other("mixture must have at least 1 element"))
|
||||
} else {
|
||||
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
||||
let properlyWeightedValues =
|
||||
values
|
||||
->E.A2.fmap(((dist, weight)) => scaleMultiplyFn(dist, weight /. totalWeight))
|
||||
->E.A.R.firstErrorOrOpen
|
||||
properlyWeightedValues->E.R.bind(values => {
|
||||
values
|
||||
|> Js.Array.sliceFrom(1)
|
||||
|> E.A.fold_left(
|
||||
(acc, x) => E.R.bind(acc, acc => pointwiseAddFn(acc, x)),
|
||||
Ok(E.A.unsafe_get(values, 0)),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
type t = GenericDist_Types.genericDist
|
||||
type error = GenericDist_Types.error
|
||||
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
|
||||
type toSampleSetFn = t => result<array<float>, error>
|
||||
type scaleMultiplyFn = (t, float) => result<t, error>
|
||||
type pointwiseAddFn = (t, t) => result<t, error>
|
||||
|
||||
let sampleN: (t, int) => result<array<float>, error>
|
||||
|
||||
let fromFloat: float => t
|
||||
|
||||
let toString: t => string
|
||||
|
||||
let normalize: t => t
|
||||
|
||||
let toFloatOperation: (
|
||||
t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~distToFloatOperation: Operation.distToFloatOperation,
|
||||
) => result<float, error>
|
||||
|
||||
let toPointSet: (
|
||||
~xyPointLength: int,
|
||||
~sampleCount: int,
|
||||
t,
|
||||
) => result<PointSetTypes.pointSetDist, error>
|
||||
|
||||
let truncate: (
|
||||
t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~leftCutoff: option<float>=?,
|
||||
~rightCutoff: option<float>=?,
|
||||
unit,
|
||||
) => result<t, error>
|
||||
|
||||
let algebraicCombination: (
|
||||
t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~toSampleSetFn: toSampleSetFn,
|
||||
~arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
|
||||
~t2: t,
|
||||
) => result<t, error>
|
||||
|
||||
let pointwiseCombination: (
|
||||
t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
|
||||
~t2: t,
|
||||
) => result<t, error>
|
||||
|
||||
let pointwiseCombinationFloat: (
|
||||
t,
|
||||
~toPointSetFn: toPointSetFn,
|
||||
~arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
|
||||
~float: float,
|
||||
) => result<t, error>
|
||||
|
||||
let mixture: (
|
||||
array<(t, float)>,
|
||||
~scaleMultiplyFn: scaleMultiplyFn,
|
||||
~pointwiseAddFn: pointwiseAddFn,
|
||||
) => result<t, error>
|
|
@ -0,0 +1,171 @@
|
|||
type functionCallInfo = GenericDist_Types.Operation.genericFunctionCallInfo
|
||||
type genericDist = GenericDist_Types.genericDist
|
||||
type error = GenericDist_Types.error
|
||||
|
||||
// TODO: It could be great to use a cache for some calculations (basically, do memoization). Also, better analytics/tracking could go a long way.
|
||||
|
||||
type env = {
|
||||
sampleCount: int,
|
||||
xyPointLength: int,
|
||||
}
|
||||
|
||||
type outputType =
|
||||
| Dist(GenericDist_Types.genericDist)
|
||||
| Float(float)
|
||||
| String(string)
|
||||
| GenDistError(GenericDist_Types.error)
|
||||
|
||||
/*
|
||||
We're going to add another function to this module later, so first define a
|
||||
local version, which is not exported.
|
||||
*/
|
||||
module OutputLocal = {
|
||||
type t = outputType
|
||||
|
||||
let toError = (t: outputType) =>
|
||||
switch t {
|
||||
| GenDistError(d) => Some(d)
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let toErrorOrUnreachable = (t: t): error => t->toError->E.O2.default((Unreachable: error))
|
||||
|
||||
let toDistR = (t: t): result<genericDist, error> =>
|
||||
switch t {
|
||||
| Dist(r) => Ok(r)
|
||||
| e => Error(toErrorOrUnreachable(e))
|
||||
}
|
||||
|
||||
let toDist = (t: t) =>
|
||||
switch t {
|
||||
| Dist(d) => Some(d)
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let toFloat = (t: t) =>
|
||||
switch t {
|
||||
| Float(d) => Some(d)
|
||||
| _ => None
|
||||
}
|
||||
|
||||
let toString = (t: t) =>
|
||||
switch t {
|
||||
| String(d) => Some(d)
|
||||
| _ => None
|
||||
}
|
||||
|
||||
//This is used to catch errors in other switch statements.
|
||||
let fromResult = (r: result<t, error>): outputType =>
|
||||
switch r {
|
||||
| Ok(t) => t
|
||||
| Error(e) => GenDistError(e)
|
||||
}
|
||||
}
|
||||
|
||||
let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||
let {sampleCount, xyPointLength} = env
|
||||
|
||||
let reCall = (~env=env, ~functionCallInfo=functionCallInfo, ()) => {
|
||||
run(~env, functionCallInfo)
|
||||
}
|
||||
|
||||
let toPointSetFn = r => {
|
||||
switch reCall(~functionCallInfo=FromDist(ToDist(ToPointSet), r), ()) {
|
||||
| Dist(PointSet(p)) => Ok(p)
|
||||
| e => Error(OutputLocal.toErrorOrUnreachable(e))
|
||||
}
|
||||
}
|
||||
|
||||
let toSampleSetFn = r => {
|
||||
switch reCall(~functionCallInfo=FromDist(ToDist(ToSampleSet(sampleCount)), r), ()) {
|
||||
| Dist(SampleSet(p)) => Ok(p)
|
||||
| e => Error(OutputLocal.toErrorOrUnreachable(e))
|
||||
}
|
||||
}
|
||||
|
||||
let scaleMultiply = (r, weight) =>
|
||||
reCall(
|
||||
~functionCallInfo=FromDist(ToDistCombination(Pointwise, #Multiply, #Float(weight)), r),
|
||||
(),
|
||||
)->OutputLocal.toDistR
|
||||
|
||||
let pointwiseAdd = (r1, r2) =>
|
||||
reCall(
|
||||
~functionCallInfo=FromDist(ToDistCombination(Pointwise, #Add, #Dist(r2)), r1),
|
||||
(),
|
||||
)->OutputLocal.toDistR
|
||||
|
||||
let fromDistFn = (subFnName: GenericDist_Types.Operation.fromDist, dist: genericDist) =>
|
||||
switch subFnName {
|
||||
| ToFloat(distToFloatOperation) =>
|
||||
GenericDist.toFloatOperation(dist, ~toPointSetFn, ~distToFloatOperation)
|
||||
->E.R2.fmap(r => Float(r))
|
||||
->OutputLocal.fromResult
|
||||
| ToString => dist->GenericDist.toString->String
|
||||
| ToDist(Inspect) => {
|
||||
Js.log2("Console log requested: ", dist)
|
||||
Dist(dist)
|
||||
}
|
||||
| ToDist(Normalize) => dist->GenericDist.normalize->Dist
|
||||
| ToDist(Truncate(leftCutoff, rightCutoff)) =>
|
||||
GenericDist.truncate(~toPointSetFn, ~leftCutoff, ~rightCutoff, dist, ())
|
||||
->E.R2.fmap(r => Dist(r))
|
||||
->OutputLocal.fromResult
|
||||
| ToDist(ToSampleSet(n)) =>
|
||||
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
|
||||
| ToDist(ToPointSet) =>
|
||||
dist
|
||||
->GenericDist.toPointSet(~xyPointLength, ~sampleCount)
|
||||
->E.R2.fmap(r => Dist(PointSet(r)))
|
||||
->OutputLocal.fromResult
|
||||
| ToDistCombination(Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
|
||||
| ToDistCombination(Algebraic, arithmeticOperation, #Dist(t2)) =>
|
||||
dist
|
||||
->GenericDist.algebraicCombination(~toPointSetFn, ~toSampleSetFn, ~arithmeticOperation, ~t2)
|
||||
->E.R2.fmap(r => Dist(r))
|
||||
->OutputLocal.fromResult
|
||||
| ToDistCombination(Pointwise, arithmeticOperation, #Dist(t2)) =>
|
||||
dist
|
||||
->GenericDist.pointwiseCombination(~toPointSetFn, ~arithmeticOperation, ~t2)
|
||||
->E.R2.fmap(r => Dist(r))
|
||||
->OutputLocal.fromResult
|
||||
| ToDistCombination(Pointwise, arithmeticOperation, #Float(float)) =>
|
||||
dist
|
||||
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~arithmeticOperation, ~float)
|
||||
->E.R2.fmap(r => Dist(r))
|
||||
->OutputLocal.fromResult
|
||||
}
|
||||
|
||||
switch functionCallInfo {
|
||||
| FromDist(subFnName, dist) => fromDistFn(subFnName, dist)
|
||||
| FromFloat(subFnName, float) =>
|
||||
reCall(~functionCallInfo=FromDist(subFnName, GenericDist.fromFloat(float)), ())
|
||||
| Mixture(dists) =>
|
||||
dists
|
||||
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd)
|
||||
->E.R2.fmap(r => Dist(r))
|
||||
->OutputLocal.fromResult
|
||||
}
|
||||
}
|
||||
|
||||
let runFromDist = (~env, ~functionCallInfo, dist) => run(~env, FromDist(functionCallInfo, dist))
|
||||
let runFromFloat = (~env, ~functionCallInfo, float) => run(~env, FromFloat(functionCallInfo, float))
|
||||
|
||||
module Output = {
|
||||
include OutputLocal
|
||||
|
||||
let fmap = (
|
||||
~env,
|
||||
input: outputType,
|
||||
functionCallInfo: GenericDist_Types.Operation.singleParamaterFunction,
|
||||
): outputType => {
|
||||
let newFnCall: result<functionCallInfo, error> = switch (functionCallInfo, input) {
|
||||
| (FromDist(fromDist), Dist(o)) => Ok(FromDist(fromDist, o))
|
||||
| (FromFloat(fromDist), Float(o)) => Ok(FromFloat(fromDist, o))
|
||||
| (_, GenDistError(r)) => Error(r)
|
||||
| (FromDist(_), _) => Error(Other("Expected dist, got something else"))
|
||||
| (FromFloat(_), _) => Error(Other("Expected float, got something else"))
|
||||
}
|
||||
newFnCall->E.R2.fmap(run(~env))->OutputLocal.fromResult
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
type env = {
|
||||
sampleCount: int,
|
||||
xyPointLength: int,
|
||||
}
|
||||
|
||||
type outputType =
|
||||
| Dist(GenericDist_Types.genericDist)
|
||||
| Float(float)
|
||||
| String(string)
|
||||
| GenDistError(GenericDist_Types.error)
|
||||
|
||||
let run: (~env: env, GenericDist_Types.Operation.genericFunctionCallInfo) => outputType
|
||||
let runFromDist: (
|
||||
~env: env,
|
||||
~functionCallInfo: GenericDist_Types.Operation.fromDist,
|
||||
GenericDist_Types.genericDist,
|
||||
) => outputType
|
||||
let runFromFloat: (
|
||||
~env: env,
|
||||
~functionCallInfo: GenericDist_Types.Operation.fromDist,
|
||||
float,
|
||||
) => outputType
|
||||
|
||||
module Output: {
|
||||
type t = outputType
|
||||
let toDist: t => option<GenericDist_Types.genericDist>
|
||||
let toDistR: t => result<GenericDist_Types.genericDist, GenericDist_Types.error>
|
||||
let toFloat: t => option<float>
|
||||
let toString: t => option<string>
|
||||
let toError: t => option<GenericDist_Types.error>
|
||||
let fmap: (~env: env, t, GenericDist_Types.Operation.singleParamaterFunction) => t
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
type genericDist =
|
||||
| PointSet(PointSetTypes.pointSetDist)
|
||||
| SampleSet(array<float>)
|
||||
| Symbolic(SymbolicDistTypes.symbolicDist)
|
||||
|
||||
type error =
|
||||
| NotYetImplemented
|
||||
| Unreachable
|
||||
| DistributionVerticalShiftIsInvalid
|
||||
| Other(string)
|
||||
|
||||
module Operation = {
|
||||
type direction =
|
||||
| Algebraic
|
||||
| Pointwise
|
||||
|
||||
type arithmeticOperation = [
|
||||
| #Add
|
||||
| #Multiply
|
||||
| #Subtract
|
||||
| #Divide
|
||||
| #Exponentiate
|
||||
| #Log
|
||||
]
|
||||
|
||||
let arithmeticToFn = (arithmetic: arithmeticOperation) =>
|
||||
switch arithmetic {
|
||||
| #Add => \"+."
|
||||
| #Multiply => \"*."
|
||||
| #Subtract => \"-."
|
||||
| #Exponentiate => \"**"
|
||||
| #Divide => \"/."
|
||||
| #Log => (a, b) => log(a) /. log(b)
|
||||
}
|
||||
|
||||
type toFloat = [
|
||||
| #Cdf(float)
|
||||
| #Inv(float)
|
||||
| #Mean
|
||||
| #Pdf(float)
|
||||
| #Sample
|
||||
]
|
||||
|
||||
type toDist =
|
||||
| Normalize
|
||||
| ToPointSet
|
||||
| ToSampleSet(int)
|
||||
| Truncate(option<float>, option<float>)
|
||||
| Inspect
|
||||
|
||||
type toFloatArray = Sample(int)
|
||||
|
||||
type fromDist =
|
||||
| ToFloat(toFloat)
|
||||
| ToDist(toDist)
|
||||
| ToDistCombination(direction, arithmeticOperation, [#Dist(genericDist) | #Float(float)])
|
||||
| ToString
|
||||
|
||||
type singleParamaterFunction =
|
||||
| FromDist(fromDist)
|
||||
| FromFloat(fromDist)
|
||||
|
||||
type genericFunctionCallInfo =
|
||||
| FromDist(fromDist, genericDist)
|
||||
| FromFloat(fromDist, float)
|
||||
| Mixture(array<(genericDist, float)>)
|
||||
|
||||
let distCallToString = (distFunction: fromDist): string =>
|
||||
switch distFunction {
|
||||
| ToFloat(#Cdf(r)) => `cdf(${E.Float.toFixed(r)})`
|
||||
| ToFloat(#Inv(r)) => `inv(${E.Float.toFixed(r)})`
|
||||
| ToFloat(#Mean) => `mean`
|
||||
| ToFloat(#Pdf(r)) => `pdf(${E.Float.toFixed(r)})`
|
||||
| ToFloat(#Sample) => `sample`
|
||||
| ToDist(Normalize) => `normalize`
|
||||
| ToDist(ToPointSet) => `toPointSet`
|
||||
| ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})`
|
||||
| ToDist(Truncate(_, _)) => `truncate`
|
||||
| ToDist(Inspect) => `inspect`
|
||||
| ToString => `toString`
|
||||
| ToDistCombination(Algebraic, _, _) => `algebraic`
|
||||
| ToDistCombination(Pointwise, _, _) => `pointwise`
|
||||
}
|
||||
|
||||
let toString = (d: genericFunctionCallInfo): string =>
|
||||
switch d {
|
||||
| FromDist(f, _) | FromFloat(f, _) => distCallToString(f)
|
||||
| Mixture(_) => `mixture`
|
||||
}
|
||||
}
|
48
packages/squiggle-lang/src/rescript/GenericDist/README.md
Normal file
48
packages/squiggle-lang/src/rescript/GenericDist/README.md
Normal file
|
@ -0,0 +1,48 @@
|
|||
# Generic Distribution Library
|
||||
|
||||
This library provides one interface to generic distributions. These distributions can either be symbolic, point set (xy-coordinates of the shape), or sample set (arrays of random samples).
|
||||
|
||||
Different internal formats (symbolic, point set, sample set) allow for benefits and features. It's common for distributions to be converted into either point sets or sample sets to enable certain functions.
|
||||
|
||||
In addition to this interface, there's a second, generic function, for calling functions on this generic distribution type. This ``genericOperation`` standardizes the inputs and outputs for these various function calls. See it's ``run()`` function.
|
||||
|
||||
Performance is very important. Some operations can take a long time to run, and even then, be inaccurate. Because of this, we plan to have a lot of logging and stack tracing functionality eventually built in.
|
||||
|
||||
## Diagram of Distribution Types
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Generic Distribution] -->B{Point Set}
|
||||
A --> C{Sample Set}
|
||||
A --> D{Symbolic}
|
||||
B ---> continuous(Continuous)
|
||||
B ---> discrete(Discrete)
|
||||
B --> mixed(Mixed)
|
||||
continuous -.-> XYshape(XYshape)
|
||||
discrete -.-> XYshape(XYshape)
|
||||
mixed -.-> continuous
|
||||
mixed -.-> discrete
|
||||
D --> Normal(Normal)
|
||||
D --> Lognormal(Lognormal)
|
||||
D --> Triangular(Triangular)
|
||||
D --> Beta(Beta)
|
||||
D --> Uniform(Uniform)
|
||||
D --> Float(Float)
|
||||
D --> Exponential(Exponential)
|
||||
D --> Cauchy(Cauchy)
|
||||
```
|
||||
|
||||
## Diagram of Generic Distribution Types
|
||||
|
||||
## Todo
|
||||
- [ ] Lots of cleanup
|
||||
- [ ] Simple test story
|
||||
- [ ] Provide decent stack traces for key calls in GenericOperation. This could be very useful for debugging.
|
||||
- [ ] Cleanup Sample Set library
|
||||
- [ ] Add memoization for calculations
|
||||
- [ ] Performance bechmarking reports
|
||||
- [ ] Remove most of DistPlus, much of which is not needed anymore
|
||||
- [ ] More functions for Sample Set, which is very minimal now
|
||||
- [ ] Allow these functions to be run on web workers
|
||||
- [ ] Refactor interpreter to use GenericDist. This might not be necessary, as the new reducer-inspired interpreter is integrated.
|
||||
|
||||
## More todos
|
|
@ -115,6 +115,7 @@ let combineShapesContinuousContinuous = (
|
|||
| #Multiply => (m1, m2) => m1 *. m2
|
||||
| #Divide => (m1, mInv2) => m1 *. mInv2
|
||||
| #Exponentiate => (m1, mInv2) => m1 ** mInv2
|
||||
| #Log => (m1, m2) => log(m1) /. log(m2)
|
||||
} // note: here, mInv2 = mean(1 / t2) ~= 1 / mean(t2)
|
||||
|
||||
// TODO: I don't know what the variances are for exponentatiation
|
||||
|
@ -232,6 +233,7 @@ let combineShapesContinuousDiscrete = (
|
|||
}
|
||||
| #Multiply
|
||||
| #Exponentiate
|
||||
| #Log
|
||||
| #Divide =>
|
||||
for j in 0 to t2n - 1 {
|
||||
// creates a new continuous shape for each one of the discrete points, and collects them in outXYShapes.
|
||||
|
|
|
@ -34,6 +34,7 @@ let toMixed = mapToAll((
|
|||
),
|
||||
))
|
||||
|
||||
//TODO WARNING: The combineAlgebraicallyWithDiscrete will break for subtraction and division, like, discrete - continous
|
||||
let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t =>
|
||||
switch (t1, t2) {
|
||||
| (Continuous(m1), Continuous(m2)) =>
|
||||
|
@ -41,7 +42,8 @@ let combineAlgebraically = (op: Operation.algebraicOperation, t1: t, t2: t): t =
|
|||
| (Continuous(m1), Discrete(m2))
|
||||
| (Discrete(m2), Continuous(m1)) =>
|
||||
Continuous.combineAlgebraicallyWithDiscrete(op, m1, m2) |> Continuous.T.toPointSetDist
|
||||
| (Discrete(m1), Discrete(m2)) => Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist
|
||||
| (Discrete(m1), Discrete(m2)) =>
|
||||
Discrete.combineAlgebraically(op, m1, m2) |> Discrete.T.toPointSetDist
|
||||
| (m1, m2) => Mixed.combineAlgebraically(op, toMixed(m1), toMixed(m2)) |> Mixed.T.toPointSetDist
|
||||
}
|
||||
|
||||
|
@ -196,7 +198,7 @@ let sampleNRendered = (n, dist) => {
|
|||
let operate = (distToFloatOp: Operation.distToFloatOperation, s): float =>
|
||||
switch distToFloatOp {
|
||||
| #Pdf(f) => pdf(f, s)
|
||||
| #Cdf(f) => pdf(f, s)
|
||||
| #Cdf(f) => cdf(f, s)
|
||||
| #Inv(f) => inv(f, s)
|
||||
| #Sample => sample(s)
|
||||
| #Mean => T.mean(s)
|
||||
|
|
|
@ -254,7 +254,7 @@ module PointwiseCombination = {
|
|||
j = t2n;
|
||||
continue;
|
||||
} else {
|
||||
console.log("Error!", i, j);
|
||||
console.log("PointwiseCombination Error", i, j);
|
||||
}
|
||||
|
||||
outX.push(x);
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
// TODO: Refactor to raise correct error when not enough samples
|
||||
|
||||
module Internals = {
|
||||
module Types = {
|
||||
type samplingStats = {
|
||||
|
|
|
@ -165,6 +165,7 @@ module Uniform = {
|
|||
let mean = (t: t) => Ok(Jstat.Uniform.mean(t.low, t.high))
|
||||
let toString = ({low, high}: t) => j`Uniform($low,$high)`
|
||||
let truncate = (low, high, t: t): t => {
|
||||
//todo: add check
|
||||
let newLow = max(E.O.default(neg_infinity, low), t.low)
|
||||
let newHigh = min(E.O.default(infinity, high), t.high)
|
||||
{low: newLow, high: newHigh}
|
||||
|
|
|
@ -32,6 +32,17 @@ module U = {
|
|||
let id = e => e
|
||||
}
|
||||
|
||||
module Tuple2 = {
|
||||
let first = (v: ('a, 'b)) => {
|
||||
let (a, _) = v
|
||||
a
|
||||
}
|
||||
let second = (v: ('a, 'b)) => {
|
||||
let (_, b) = v
|
||||
b
|
||||
}
|
||||
}
|
||||
|
||||
module O = {
|
||||
let dimap = (sFn, rFn, e) =>
|
||||
switch e {
|
||||
|
@ -164,6 +175,10 @@ module R = {
|
|||
errorCondition(r) ? Error(errorMessage) : Ok(r)
|
||||
}
|
||||
|
||||
module R2 = {
|
||||
let fmap = (a,b) => R.fmap(b,a)
|
||||
}
|
||||
|
||||
let safe_fn_of_string = (fn, s: string): option<'a> =>
|
||||
try Some(fn(s)) catch {
|
||||
| _ => None
|
||||
|
@ -443,7 +458,7 @@ module A = {
|
|||
}
|
||||
|
||||
module A2 = {
|
||||
let fmap = (a, b) => Array.map(b, a)
|
||||
let fmap = (a,b) => A.fmap(b,a)
|
||||
let joinWith = (a, b) => A.joinWith(b, a)
|
||||
}
|
||||
|
||||
|
|
|
@ -7,10 +7,11 @@ type algebraicOperation = [
|
|||
| #Subtract
|
||||
| #Divide
|
||||
| #Exponentiate
|
||||
| #Log
|
||||
]
|
||||
@genType
|
||||
type pointwiseOperation = [#Add | #Multiply | #Exponentiate]
|
||||
type scaleOperation = [#Multiply | #Exponentiate | #Log]
|
||||
type scaleOperation = [#Multiply | #Exponentiate | #Log | #Divide]
|
||||
type distToFloatOperation = [
|
||||
| #Pdf(float)
|
||||
| #Cdf(float)
|
||||
|
@ -28,6 +29,7 @@ module Algebraic = {
|
|||
| #Multiply => \"*."
|
||||
| #Exponentiate => \"**"
|
||||
| #Divide => \"/."
|
||||
| #Log => (a, b) => log(a) /. log(b)
|
||||
}
|
||||
|
||||
let applyFn = (t, f1, f2) =>
|
||||
|
@ -43,6 +45,7 @@ module Algebraic = {
|
|||
| #Multiply => "*"
|
||||
| #Exponentiate => "**"
|
||||
| #Divide => "/"
|
||||
| #Log => "log"
|
||||
}
|
||||
|
||||
let format = (a, b, c) => b ++ (" " ++ (toString(a) ++ (" " ++ c)))
|
||||
|
@ -79,6 +82,7 @@ module Scale = {
|
|||
let toFn = x =>
|
||||
switch x {
|
||||
| #Multiply => \"*."
|
||||
| #Divide => \"/."
|
||||
| #Exponentiate => \"**"
|
||||
| #Log => (a, b) => log(a) /. log(b)
|
||||
}
|
||||
|
@ -86,6 +90,7 @@ module Scale = {
|
|||
let format = (operation: t, value, scaleBy) =>
|
||||
switch operation {
|
||||
| #Multiply => j`verticalMultiply($value, $scaleBy) `
|
||||
| #Divide => j`verticalDivide($value, $scaleBy) `
|
||||
| #Exponentiate => j`verticalExponentiate($value, $scaleBy) `
|
||||
| #Log => j`verticalLog($value, $scaleBy) `
|
||||
}
|
||||
|
@ -93,6 +98,7 @@ module Scale = {
|
|||
let toIntegralSumCacheFn = x =>
|
||||
switch x {
|
||||
| #Multiply => (a, b) => Some(a *. b)
|
||||
| #Divide => (a, b) => Some(a /. b)
|
||||
| #Exponentiate => (_, _) => None
|
||||
| #Log => (_, _) => None
|
||||
}
|
||||
|
@ -100,6 +106,7 @@ module Scale = {
|
|||
let toIntegralCacheFn = x =>
|
||||
switch x {
|
||||
| #Multiply => (_, _) => None // TODO: this could probably just be multiplied out (using Continuous.scaleBy)
|
||||
| #Divide => (_, _) => None
|
||||
| #Exponentiate => (_, _) => None
|
||||
| #Log => (_, _) => None
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user