Merge pull request #525 from quantified-uncertainty/pass-environment

Pass environment down to GenericDist
This commit is contained in:
Sam Nolan 2022-05-13 16:28:21 -04:00 committed by GitHub
commit 2da5d5394e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 77 additions and 57 deletions

View File

@ -1,13 +1,6 @@
module ExpressionValue = ReducerInterface_ExpressionValue module ExpressionValue = ReducerInterface_ExpressionValue
type expressionValue = ReducerInterface_ExpressionValue.expressionValue type expressionValue = ReducerInterface_ExpressionValue.expressionValue
let defaultEnv: DistributionOperation.env = {
sampleCount: MagicNumbers.Environment.defaultSampleCount,
xyPointLength: MagicNumbers.Environment.defaultXYPointLength,
}
let runGenericOperation = DistributionOperation.run(~env=defaultEnv)
module Helpers = { module Helpers = {
let arithmeticMap = r => let arithmeticMap = r =>
switch r { switch r {
@ -39,37 +32,44 @@ module Helpers = {
let toFloatFn = ( let toFloatFn = (
fnCall: DistributionTypes.DistributionOperation.toFloat, fnCall: DistributionTypes.DistributionOperation.toFloat,
dist: DistributionTypes.genericDist, dist: DistributionTypes.genericDist,
~env: DistributionOperation.env,
) => { ) => {
FromDist(DistributionTypes.DistributionOperation.ToFloat(fnCall), dist) FromDist(DistributionTypes.DistributionOperation.ToFloat(fnCall), dist)
->runGenericOperation ->DistributionOperation.run(~env)
->Some ->Some
} }
let toStringFn = ( let toStringFn = (
fnCall: DistributionTypes.DistributionOperation.toString, fnCall: DistributionTypes.DistributionOperation.toString,
dist: DistributionTypes.genericDist, dist: DistributionTypes.genericDist,
~env: DistributionOperation.env,
) => { ) => {
FromDist(DistributionTypes.DistributionOperation.ToString(fnCall), dist) FromDist(DistributionTypes.DistributionOperation.ToString(fnCall), dist)
->runGenericOperation ->DistributionOperation.run(~env)
->Some ->Some
} }
let toBoolFn = ( let toBoolFn = (
fnCall: DistributionTypes.DistributionOperation.toBool, fnCall: DistributionTypes.DistributionOperation.toBool,
dist: DistributionTypes.genericDist, dist: DistributionTypes.genericDist,
~env: DistributionOperation.env,
) => { ) => {
FromDist(DistributionTypes.DistributionOperation.ToBool(fnCall), dist) FromDist(DistributionTypes.DistributionOperation.ToBool(fnCall), dist)
->runGenericOperation ->DistributionOperation.run(~env)
->Some ->Some
} }
let toDistFn = (fnCall: DistributionTypes.DistributionOperation.toDist, dist) => { let toDistFn = (
fnCall: DistributionTypes.DistributionOperation.toDist,
dist,
~env: DistributionOperation.env,
) => {
FromDist(DistributionTypes.DistributionOperation.ToDist(fnCall), dist) FromDist(DistributionTypes.DistributionOperation.ToDist(fnCall), dist)
->runGenericOperation ->DistributionOperation.run(~env)
->Some ->Some
} }
let twoDiststoDistFn = (direction, arithmetic, dist1, dist2) => { let twoDiststoDistFn = (direction, arithmetic, dist1, dist2, ~env: DistributionOperation.env) => {
FromDist( FromDist(
DistributionTypes.DistributionOperation.ToDistCombination( DistributionTypes.DistributionOperation.ToDistCombination(
direction, direction,
@ -77,7 +77,7 @@ module Helpers = {
#Dist(dist2), #Dist(dist2),
), ),
dist1, dist1,
)->runGenericOperation )->DistributionOperation.run(~env)
} }
let parseNumber = (args: expressionValue): Belt.Result.t<float, string> => let parseNumber = (args: expressionValue): Belt.Result.t<float, string> =>
@ -104,33 +104,38 @@ module Helpers = {
let mixtureWithGivenWeights = ( let mixtureWithGivenWeights = (
distributions: array<DistributionTypes.genericDist>, distributions: array<DistributionTypes.genericDist>,
weights: array<float>, weights: array<float>,
~env: DistributionOperation.env,
): DistributionOperation.outputType => ): DistributionOperation.outputType =>
E.A.length(distributions) == E.A.length(weights) E.A.length(distributions) == E.A.length(weights)
? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation ? Mixture(Belt.Array.zip(distributions, weights))->DistributionOperation.run(~env)
: GenDistError( : GenDistError(
ArgumentError("Error, mixture call has different number of distributions and weights"), ArgumentError("Error, mixture call has different number of distributions and weights"),
) )
let mixtureWithDefaultWeights = ( let mixtureWithDefaultWeights = (
distributions: array<DistributionTypes.genericDist>, distributions: array<DistributionTypes.genericDist>,
~env: DistributionOperation.env,
): DistributionOperation.outputType => { ): DistributionOperation.outputType => {
let length = E.A.length(distributions) let length = E.A.length(distributions)
let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length)) let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length))
mixtureWithGivenWeights(distributions, weights) mixtureWithGivenWeights(distributions, weights, ~env)
} }
let mixture = (args: array<expressionValue>): DistributionOperation.outputType => { let mixture = (
args: array<expressionValue>,
~env: DistributionOperation.env,
): DistributionOperation.outputType => {
let error = (err: string): DistributionOperation.outputType => let error = (err: string): DistributionOperation.outputType =>
err->DistributionTypes.ArgumentError->GenDistError err->DistributionTypes.ArgumentError->GenDistError
switch args { switch args {
| [EvArray(distributions)] => | [EvArray(distributions)] =>
switch parseDistributionArray(distributions) { switch parseDistributionArray(distributions) {
| Ok(distrs) => mixtureWithDefaultWeights(distrs) | Ok(distrs) => mixtureWithDefaultWeights(distrs, ~env)
| Error(err) => error(err) | Error(err) => error(err)
} }
| [EvArray(distributions), EvArray(weights)] => | [EvArray(distributions), EvArray(weights)] =>
switch (parseDistributionArray(distributions), parseNumberArray(weights)) { switch (parseDistributionArray(distributions), parseNumberArray(weights)) {
| (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts) | (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts, ~env)
| (Error(err), Ok(_)) => error(err) | (Error(err), Ok(_)) => error(err)
| (Ok(_), Error(err)) => error(err) | (Ok(_), Error(err)) => error(err)
| (Error(err1), Error(err2)) => error(`${err1}|${err2}`) | (Error(err1), Error(err2)) => error(`${err1}|${err2}`)
@ -143,14 +148,14 @@ module Helpers = {
Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1), Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1),
) )
switch E.R.merge(distributions, weights) { switch E.R.merge(distributions, weights) {
| Ok(d, w) => mixtureWithGivenWeights(d, w) | Ok(d, w) => mixtureWithGivenWeights(d, w, ~env)
| Error(err) => error(err) | Error(err) => error(err)
} }
} }
| Some(EvNumber(_)) | Some(EvNumber(_))
| Some(EvDistribution(_)) => | Some(EvDistribution(_)) =>
switch parseDistributionArray(args) { switch parseDistributionArray(args) {
| Ok(distributions) => mixtureWithDefaultWeights(distributions) | Ok(distributions) => mixtureWithDefaultWeights(distributions, ~env)
| Error(err) => error(err) | Error(err) => error(err)
} }
| _ => error("Last argument of mx must be array or distribution") | _ => error("Last argument of mx must be array or distribution")
@ -193,9 +198,10 @@ module SymbolicConstructors = {
} }
} }
let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment): option< let dispatchToGenericOutput = (
DistributionOperation.outputType, call: ExpressionValue.functionCall,
> => { env: DistributionOperation.env,
): option<DistributionOperation.outputType> => {
let (fnName, args) = call let (fnName, args) = call
switch (fnName, args) { switch (fnName, args) {
| ("exponential" as fnName, [EvNumber(f)]) => | ("exponential" as fnName, [EvNumber(f)]) =>
@ -215,13 +221,13 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
SymbolicConstructors.threeFloat(fnName) SymbolicConstructors.threeFloat(fnName)
->E.R.bind(r => r(f1, f2, f3)) ->E.R.bind(r => r(f1, f2, f3))
->SymbolicConstructors.symbolicResultToOutput ->SymbolicConstructors.symbolicResultToOutput
| ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist) | ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist, ~env)
| ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist) | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist, ~env)
| ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist) | ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist, ~env)
| ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist) | ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist, ~env)
| ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist) | ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist, ~env)
| ("toSparkline", [EvDistribution(dist), EvNumber(n)]) => | ("toSparkline", [EvDistribution(dist), EvNumber(n)]) =>
Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist) Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist, ~env)
| ("exp", [EvDistribution(a)]) => | ("exp", [EvDistribution(a)]) =>
// https://mathjs.org/docs/reference/functions/exp.html // https://mathjs.org/docs/reference/functions/exp.html
Helpers.twoDiststoDistFn( Helpers.twoDiststoDistFn(
@ -229,60 +235,74 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
"pow", "pow",
GenericDist.fromFloat(MagicNumbers.Math.e), GenericDist.fromFloat(MagicNumbers.Math.e),
a, a,
~env,
)->Some )->Some
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist) | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env)
| ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) =>
Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a))) Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env))
| ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist) | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env)
| ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist) | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env)
| ("scaleLog", [EvDistribution(dist)]) => | ("scaleLog", [EvDistribution(dist)]) =>
Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist) Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist, ~env)
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist) | ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist, ~env)
| ("scaleLog", [EvDistribution(dist), EvNumber(float)]) => | ("scaleLog", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Scale(#Logarithm, float), dist) Helpers.toDistFn(Scale(#Logarithm, float), dist, ~env)
| ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) => | ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) =>
Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist) Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist, ~env)
| ("scalePow", [EvDistribution(dist), EvNumber(float)]) => | ("scalePow", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Scale(#Power, float), dist) Helpers.toDistFn(Scale(#Power, float), dist, ~env)
| ("scaleExp", [EvDistribution(dist)]) => | ("scaleExp", [EvDistribution(dist)]) =>
Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist) Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist, ~env)
| ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist) | ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist, ~env)
| ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist) | ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist, ~env)
| ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist) | ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist, ~env)
| ("toSampleSet", [EvDistribution(dist), EvNumber(float)]) => | ("toSampleSet", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist) Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist, ~env)
| ("toSampleSet", [EvDistribution(dist)]) => | ("toSampleSet", [EvDistribution(dist)]) =>
Helpers.toDistFn(ToSampleSet(MagicNumbers.Environment.defaultSampleCount), dist) Helpers.toDistFn(ToSampleSet(env.sampleCount), dist, ~env)
| ("fromSamples", [EvArray(inputArray)]) => { | ("fromSamples", [EvArray(inputArray)]) => {
let _wrapInputErrors = x => SampleSetDist.NonNumericInput(x) let _wrapInputErrors = x => SampleSetDist.NonNumericInput(x)
let parsedArray = Helpers.parseNumberArray(inputArray)->E.R2.errMap(_wrapInputErrors) let parsedArray = Helpers.parseNumberArray(inputArray)->E.R2.errMap(_wrapInputErrors)
switch parsedArray { switch parsedArray {
| Ok(array) => runGenericOperation(FromSamples(array)) | Ok(array) => DistributionOperation.run(FromSamples(array), ~env)
| Error(e) => GenDistError(SampleSetError(e)) | Error(e) => GenDistError(SampleSetError(e))
}->Some }->Some
} }
| ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist) | ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist, ~env)
| ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) => | ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Truncate(Some(float), None), dist) Helpers.toDistFn(Truncate(Some(float), None), dist, ~env)
| ("truncateRight", [EvDistribution(dist), EvNumber(float)]) => | ("truncateRight", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Truncate(None, Some(float)), dist) Helpers.toDistFn(Truncate(None, Some(float)), dist, ~env)
| ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) => | ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) =>
Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist) Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist, ~env)
| ("mx" | "mixture", args) => Helpers.mixture(args)->Some | ("mx" | "mixture", args) => Helpers.mixture(args, ~env)->Some
| ("log", [EvDistribution(a)]) => | ("log", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn( Helpers.twoDiststoDistFn(
Algebraic(AsDefault), Algebraic(AsDefault),
"log", "log",
a, a,
GenericDist.fromFloat(MagicNumbers.Math.e), GenericDist.fromFloat(MagicNumbers.Math.e),
~env,
)->Some )->Some
| ("log10", [EvDistribution(a)]) => | ("log10", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn(Algebraic(AsDefault), "log", a, GenericDist.fromFloat(10.0))->Some Helpers.twoDiststoDistFn(
Algebraic(AsDefault),
"log",
a,
GenericDist.fromFloat(10.0),
~env,
)->Some
| ("unaryMinus", [EvDistribution(a)]) => | ("unaryMinus", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn(Algebraic(AsDefault), "multiply", a, GenericDist.fromFloat(-1.0))->Some Helpers.twoDiststoDistFn(
Algebraic(AsDefault),
"multiply",
a,
GenericDist.fromFloat(-1.0),
~env,
)->Some
| (("add" | "multiply" | "subtract" | "divide" | "pow" | "log") as arithmetic, [_, _] as args) => | (("add" | "multiply" | "subtract" | "divide" | "pow" | "log") as arithmetic, [_, _] as args) =>
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd) Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd, ~env)
) )
| ( | (
("dotAdd" ("dotAdd"
@ -293,7 +313,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
[_, _] as args, [_, _] as args,
) => ) =>
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd) Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd, ~env)
) )
| ("dotExp", [EvDistribution(a)]) => | ("dotExp", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn( Helpers.twoDiststoDistFn(
@ -301,6 +321,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
"dotPow", "dotPow",
GenericDist.fromFloat(MagicNumbers.Math.e), GenericDist.fromFloat(MagicNumbers.Math.e),
a, a,
~env,
)->Some )->Some
| _ => None | _ => None
} }

View File

@ -1,4 +1,3 @@
let defaultEnv: DistributionOperation.env
let dispatch: ( let dispatch: (
ReducerInterface_ExpressionValue.functionCall, ReducerInterface_ExpressionValue.functionCall,
ReducerInterface_ExpressionValue.environment, ReducerInterface_ExpressionValue.environment,

View File

@ -77,7 +77,7 @@ let distributionErrorToString = DistributionTypes.Error.toString
type lambdaValue = ReducerInterface_ExpressionValue.lambdaValue type lambdaValue = ReducerInterface_ExpressionValue.lambdaValue
@genType @genType
let defaultSamplingEnv = ReducerInterface_GenericDistribution.defaultEnv let defaultSamplingEnv = DistributionOperation.defaultEnv
@genType @genType
type environment = ReducerInterface_ExpressionValue.environment type environment = ReducerInterface_ExpressionValue.environment