Merged with master

This commit is contained in:
Ozzie Gooen 2022-05-15 10:54:16 -04:00
commit 47d7ef49cf
13 changed files with 243 additions and 101 deletions

View File

@ -28,10 +28,10 @@
"@storybook/react": "^6.4.22", "@storybook/react": "^6.4.22",
"@testing-library/jest-dom": "^5.16.4", "@testing-library/jest-dom": "^5.16.4",
"@testing-library/react": "^13.2.0", "@testing-library/react": "^13.2.0",
"@testing-library/user-event": "^14.1.1", "@testing-library/user-event": "^14.2.0",
"@types/jest": "^27.5.0", "@types/jest": "^27.5.0",
"@types/lodash": "^4.14.182", "@types/lodash": "^4.14.182",
"@types/node": "^17.0.31", "@types/node": "^17.0.32",
"@types/react": "^18.0.9", "@types/react": "^18.0.9",
"@types/react-dom": "^18.0.2", "@types/react-dom": "^18.0.2",
"@types/styled-components": "^5.1.24", "@types/styled-components": "^5.1.24",

View File

@ -76,7 +76,9 @@ export const FunctionChart: React.FC<FunctionChartProps> = ({
chartSettings.count chartSettings.count
); );
type point = { x: number; value: result<Distribution, string> }; type point = { x: number; value: result<Distribution, string> };
let valueData: point[] = data1.map((x) => { let valueData: point[] = React.useMemo(
() =>
data1.map((x) => {
let result = runForeign(fn, [x], environment); let result = runForeign(fn, [x], environment);
if (result.tag === "Ok") { if (result.tag === "Ok") {
if (result.value.tag == "distribution") { if (result.value.tag == "distribution") {
@ -97,7 +99,9 @@ export const FunctionChart: React.FC<FunctionChartProps> = ({
value: { tag: "Error", value: errorValueToString(result.value) }, value: { tag: "Error", value: errorValueToString(result.value) },
}; };
} }
}); }),
[environment, fn]
);
let initialPartition: [ let initialPartition: [
{ x: number; value: Distribution }[], { x: number; value: Distribution }[],
@ -141,10 +145,10 @@ export const FunctionChart: React.FC<FunctionChartProps> = ({
/> />
{showChart} {showChart}
{_.entries(groupedErrors).map(([errorName, errorPoints]) => ( {_.entries(groupedErrors).map(([errorName, errorPoints]) => (
<ErrorBox heading={errorName}> <ErrorBox key={errorName} heading={errorName}>
Values:{" "} Values:{" "}
{errorPoints {errorPoints
.map((r) => <NumberShower number={r.x} />) .map((r, i) => <NumberShower key={i} number={r.x} />)
.reduce((a, b) => ( .reduce((a, b) => (
<> <>
{a}, {b} {a}, {b}

View File

@ -148,8 +148,9 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
case "array": case "array":
return ( return (
<VariableBox heading="Array" showTypes={showTypes}> <VariableBox heading="Array" showTypes={showTypes}>
{expression.value.map((r) => ( {expression.value.map((r, i) => (
<SquiggleItem <SquiggleItem
key={i}
expression={r} expression={r}
width={width !== undefined ? width - 20 : width} width={width !== undefined ? width - 20 : width}
height={50} height={50}
@ -166,7 +167,7 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
return ( return (
<VariableBox heading="Record" showTypes={showTypes}> <VariableBox heading="Record" showTypes={showTypes}>
{Object.entries(expression.value).map(([key, r]) => ( {Object.entries(expression.value).map(([key, r]) => (
<> <div key={key}>
<RecordKeyHeader>{key}</RecordKeyHeader> <RecordKeyHeader>{key}</RecordKeyHeader>
<SquiggleItem <SquiggleItem
expression={r} expression={r}
@ -178,14 +179,14 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
chartSettings={chartSettings} chartSettings={chartSettings}
environment={environment} environment={environment}
/> />
</> </div>
))} ))}
</VariableBox> </VariableBox>
); );
case "arraystring": case "arraystring":
return ( return (
<VariableBox heading="Array String" showTypes={showTypes}> <VariableBox heading="Array String" showTypes={showTypes}>
{expression.value.map((r) => `"${r}"`)} {expression.value.map((r) => `"${r}"`).join(", ")}
</VariableBox> </VariableBox>
); );
case "lambda": case "lambda":

View File

@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
) )
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0})) let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0})) let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0}))
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1)) let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
exception KlFailed exception KlFailed

View File

@ -3,9 +3,15 @@ open Expect
open TestHelpers open TestHelpers
open GenericDist_Fixtures open GenericDist_Fixtures
// integral from low to high of 1 / (high - low) log(normal(mean, stdev)(x) / (1 / (high - low))) dx
let klNormalUniform = (mean, stdev, low, high): float =>
-.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. stdev ** 2.0)) +.
1.0 /.
stdev ** 2.0 *.
(mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0)
describe("klDivergence: continuous -> continuous -> float", () => { describe("klDivergence: continuous -> continuous -> float", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env) let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
exception KlFailed
let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => { let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
test("of two uniforms is equal to the analytic expression", () => { test("of two uniforms is equal to the analytic expression", () => {
@ -59,6 +65,20 @@ describe("klDivergence: continuous -> continuous -> float", () => {
} }
} }
}) })
test("of a normal and a uniform is equal to the formula", () => {
let prediction = normalDist10
let answer = uniformDist
let kl = klDivergence(prediction, answer)
let analyticalKl = klNormalUniform(10.0, 2.0, 9.0, 10.0)
switch kl {
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1)
| Error(err) => {
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
}
})
}) })
describe("klDivergence: discrete -> discrete -> float", () => { describe("klDivergence: discrete -> discrete -> float", () => {
@ -96,6 +116,64 @@ describe("klDivergence: discrete -> discrete -> float", () => {
}) })
}) })
describe("klDivergence: mixed -> mixed -> float", () => {
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
let mixture' = a => DistributionTypes.DistributionOperation.Mixture(a)
let mixture = a => {
let dist' = a->mixture'->run
switch dist' {
| Dist(dist) => dist
| _ => raise(MixtureFailed)
}
}
let a = [(point1, 1.0), (uniformDist, 1.0)]->mixture
let b = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture
let c = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture
let d =
[(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]->mixture
test("finite klDivergence produces correct answer", () => {
let prediction = b
let answer = a
let kl = klDivergence(prediction, answer)
// high = 10; low = 9; mean = 10; stdev = 2
let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0
let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0)
switch kl {
| Ok(kl') =>
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
| Error(err) =>
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
})
test("returns infinity when infinite", () => {
let prediction = a
let answer = b
let kl = klDivergence(prediction, answer)
switch kl {
| Ok(kl') => kl'->expect->toEqual(infinity)
| Error(err) =>
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
})
test("finite klDivergence produces correct answer", () => {
let prediction = d
let answer = c
let kl = klDivergence(prediction, answer)
let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array
let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0)
switch kl {
| Ok(kl') =>
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
| Error(err) =>
Js.Console.log(DistributionTypes.Error.toString(err))
raise(KlFailed)
}
})
})
describe("combineAlongSupportOfSecondArgument0", () => { describe("combineAlongSupportOfSecondArgument0", () => {
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later. // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
test("test on two uniforms", _ => { test("test on two uniforms", _ => {

View File

@ -1,6 +1,6 @@
{ {
"name": "@quri/squiggle-lang", "name": "@quri/squiggle-lang",
"version": "0.2.8", "version": "0.2.9",
"homepage": "https://squiggle-language.com", "homepage": "https://squiggle-language.com",
"license": "MIT", "license": "MIT",
"scripts": { "scripts": {

View File

@ -302,10 +302,9 @@ module T = Dist({
} }
let klDivergence = (prediction: t, answer: t) => { let klDivergence = (prediction: t, answer: t) => {
Error(Operation.NotYetImplemented) let klDiscretePart = Discrete.T.klDivergence(prediction.discrete, answer.discrete)
// combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap( let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
// integralEndY, E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
// )
} }
}) })

View File

@ -200,6 +200,7 @@ module T = Dist({
switch (t1, t2) { switch (t1, t2) {
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
| _ => Error(NotYetImplemented) | _ => Error(NotYetImplemented)
} }
}) })

View File

@ -1,13 +1,6 @@
module ExpressionValue = ReducerInterface_ExpressionValue module ExpressionValue = ReducerInterface_ExpressionValue
type expressionValue = ReducerInterface_ExpressionValue.expressionValue type expressionValue = ReducerInterface_ExpressionValue.expressionValue
let defaultEnv: DistributionOperation.env = {
sampleCount: MagicNumbers.Environment.defaultSampleCount,
xyPointLength: MagicNumbers.Environment.defaultXYPointLength,
}
let runGenericOperation = DistributionOperation.run(~env=defaultEnv)
module Helpers = { module Helpers = {
let arithmeticMap = r => let arithmeticMap = r =>
switch r { switch r {
@ -39,37 +32,44 @@ module Helpers = {
let toFloatFn = ( let toFloatFn = (
fnCall: DistributionTypes.DistributionOperation.toFloat, fnCall: DistributionTypes.DistributionOperation.toFloat,
dist: DistributionTypes.genericDist, dist: DistributionTypes.genericDist,
~env: DistributionOperation.env,
) => { ) => {
FromDist(DistributionTypes.DistributionOperation.ToFloat(fnCall), dist) FromDist(DistributionTypes.DistributionOperation.ToFloat(fnCall), dist)
->runGenericOperation ->DistributionOperation.run(~env)
->Some ->Some
} }
let toStringFn = ( let toStringFn = (
fnCall: DistributionTypes.DistributionOperation.toString, fnCall: DistributionTypes.DistributionOperation.toString,
dist: DistributionTypes.genericDist, dist: DistributionTypes.genericDist,
~env: DistributionOperation.env,
) => { ) => {
FromDist(DistributionTypes.DistributionOperation.ToString(fnCall), dist) FromDist(DistributionTypes.DistributionOperation.ToString(fnCall), dist)
->runGenericOperation ->DistributionOperation.run(~env)
->Some ->Some
} }
let toBoolFn = ( let toBoolFn = (
fnCall: DistributionTypes.DistributionOperation.toBool, fnCall: DistributionTypes.DistributionOperation.toBool,
dist: DistributionTypes.genericDist, dist: DistributionTypes.genericDist,
~env: DistributionOperation.env,
) => { ) => {
FromDist(DistributionTypes.DistributionOperation.ToBool(fnCall), dist) FromDist(DistributionTypes.DistributionOperation.ToBool(fnCall), dist)
->runGenericOperation ->DistributionOperation.run(~env)
->Some ->Some
} }
let toDistFn = (fnCall: DistributionTypes.DistributionOperation.toDist, dist) => { let toDistFn = (
fnCall: DistributionTypes.DistributionOperation.toDist,
dist,
~env: DistributionOperation.env,
) => {
FromDist(DistributionTypes.DistributionOperation.ToDist(fnCall), dist) FromDist(DistributionTypes.DistributionOperation.ToDist(fnCall), dist)
->runGenericOperation ->DistributionOperation.run(~env)
->Some ->Some
} }
let twoDiststoDistFn = (direction, arithmetic, dist1, dist2) => { let twoDiststoDistFn = (direction, arithmetic, dist1, dist2, ~env: DistributionOperation.env) => {
FromDist( FromDist(
DistributionTypes.DistributionOperation.ToDistCombination( DistributionTypes.DistributionOperation.ToDistCombination(
direction, direction,
@ -77,7 +77,7 @@ module Helpers = {
#Dist(dist2), #Dist(dist2),
), ),
dist1, dist1,
)->runGenericOperation )->DistributionOperation.run(~env)
} }
let parseNumber = (args: expressionValue): Belt.Result.t<float, string> => let parseNumber = (args: expressionValue): Belt.Result.t<float, string> =>
@ -104,33 +104,38 @@ module Helpers = {
let mixtureWithGivenWeights = ( let mixtureWithGivenWeights = (
distributions: array<DistributionTypes.genericDist>, distributions: array<DistributionTypes.genericDist>,
weights: array<float>, weights: array<float>,
~env: DistributionOperation.env,
): DistributionOperation.outputType => ): DistributionOperation.outputType =>
E.A.length(distributions) == E.A.length(weights) E.A.length(distributions) == E.A.length(weights)
? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation ? Mixture(Belt.Array.zip(distributions, weights))->DistributionOperation.run(~env)
: GenDistError( : GenDistError(
ArgumentError("Error, mixture call has different number of distributions and weights"), ArgumentError("Error, mixture call has different number of distributions and weights"),
) )
let mixtureWithDefaultWeights = ( let mixtureWithDefaultWeights = (
distributions: array<DistributionTypes.genericDist>, distributions: array<DistributionTypes.genericDist>,
~env: DistributionOperation.env,
): DistributionOperation.outputType => { ): DistributionOperation.outputType => {
let length = E.A.length(distributions) let length = E.A.length(distributions)
let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length)) let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length))
mixtureWithGivenWeights(distributions, weights) mixtureWithGivenWeights(distributions, weights, ~env)
} }
let mixture = (args: array<expressionValue>): DistributionOperation.outputType => { let mixture = (
args: array<expressionValue>,
~env: DistributionOperation.env,
): DistributionOperation.outputType => {
let error = (err: string): DistributionOperation.outputType => let error = (err: string): DistributionOperation.outputType =>
err->DistributionTypes.ArgumentError->GenDistError err->DistributionTypes.ArgumentError->GenDistError
switch args { switch args {
| [EvArray(distributions)] => | [EvArray(distributions)] =>
switch parseDistributionArray(distributions) { switch parseDistributionArray(distributions) {
| Ok(distrs) => mixtureWithDefaultWeights(distrs) | Ok(distrs) => mixtureWithDefaultWeights(distrs, ~env)
| Error(err) => error(err) | Error(err) => error(err)
} }
| [EvArray(distributions), EvArray(weights)] => | [EvArray(distributions), EvArray(weights)] =>
switch (parseDistributionArray(distributions), parseNumberArray(weights)) { switch (parseDistributionArray(distributions), parseNumberArray(weights)) {
| (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts) | (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts, ~env)
| (Error(err), Ok(_)) => error(err) | (Error(err), Ok(_)) => error(err)
| (Ok(_), Error(err)) => error(err) | (Ok(_), Error(err)) => error(err)
| (Error(err1), Error(err2)) => error(`${err1}|${err2}`) | (Error(err1), Error(err2)) => error(`${err1}|${err2}`)
@ -143,14 +148,14 @@ module Helpers = {
Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1), Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1),
) )
switch E.R.merge(distributions, weights) { switch E.R.merge(distributions, weights) {
| Ok(d, w) => mixtureWithGivenWeights(d, w) | Ok(d, w) => mixtureWithGivenWeights(d, w, ~env)
| Error(err) => error(err) | Error(err) => error(err)
} }
} }
| Some(EvNumber(_)) | Some(EvNumber(_))
| Some(EvDistribution(_)) => | Some(EvDistribution(_)) =>
switch parseDistributionArray(args) { switch parseDistributionArray(args) {
| Ok(distributions) => mixtureWithDefaultWeights(distributions) | Ok(distributions) => mixtureWithDefaultWeights(distributions, ~env)
| Error(err) => error(err) | Error(err) => error(err)
} }
| _ => error("Last argument of mx must be array or distribution") | _ => error("Last argument of mx must be array or distribution")
@ -193,9 +198,10 @@ module SymbolicConstructors = {
} }
} }
let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment): option< let dispatchToGenericOutput = (
DistributionOperation.outputType, call: ExpressionValue.functionCall,
> => { env: DistributionOperation.env,
): option<DistributionOperation.outputType> => {
let (fnName, args) = call let (fnName, args) = call
switch (fnName, args) { switch (fnName, args) {
| ("exponential" as fnName, [EvNumber(f)]) => | ("exponential" as fnName, [EvNumber(f)]) =>
@ -215,16 +221,16 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
SymbolicConstructors.threeFloat(fnName) SymbolicConstructors.threeFloat(fnName)
->E.R.bind(r => r(f1, f2, f3)) ->E.R.bind(r => r(f1, f2, f3))
->SymbolicConstructors.symbolicResultToOutput ->SymbolicConstructors.symbolicResultToOutput
| ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist) | ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist, ~env)
| ("sampleN", [EvDistribution(dist), EvNumber(n)]) => Some( | ("sampleN", [EvDistribution(dist), EvNumber(n)]) => Some(
FloatArray(GenericDist.sampleN(dist, Belt.Int.fromFloat(n))), FloatArray(GenericDist.sampleN(dist, Belt.Int.fromFloat(n))),
) )
| ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist) | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist, ~env)
| ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist) | ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist, ~env)
| ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist) | ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist, ~env)
| ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist) | ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist, ~env)
| ("toSparkline", [EvDistribution(dist), EvNumber(n)]) => | ("toSparkline", [EvDistribution(dist), EvNumber(n)]) =>
Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist) Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist, ~env)
| ("exp", [EvDistribution(a)]) => | ("exp", [EvDistribution(a)]) =>
// https://mathjs.org/docs/reference/functions/exp.html // https://mathjs.org/docs/reference/functions/exp.html
Helpers.twoDiststoDistFn( Helpers.twoDiststoDistFn(
@ -232,60 +238,74 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
"pow", "pow",
GenericDist.fromFloat(MagicNumbers.Math.e), GenericDist.fromFloat(MagicNumbers.Math.e),
a, a,
~env,
)->Some )->Some
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist) | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env)
| ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) =>
Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a))) Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env))
| ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist) | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env)
| ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist) | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env)
| ("scaleLog", [EvDistribution(dist)]) => | ("scaleLog", [EvDistribution(dist)]) =>
Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist) Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist, ~env)
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist) | ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist, ~env)
| ("scaleLog", [EvDistribution(dist), EvNumber(float)]) => | ("scaleLog", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Scale(#Logarithm, float), dist) Helpers.toDistFn(Scale(#Logarithm, float), dist, ~env)
| ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) => | ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) =>
Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist) Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist, ~env)
| ("scalePow", [EvDistribution(dist), EvNumber(float)]) => | ("scalePow", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Scale(#Power, float), dist) Helpers.toDistFn(Scale(#Power, float), dist, ~env)
| ("scaleExp", [EvDistribution(dist)]) => | ("scaleExp", [EvDistribution(dist)]) =>
Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist) Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist, ~env)
| ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist) | ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist, ~env)
| ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist) | ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist, ~env)
| ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist) | ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist, ~env)
| ("toSampleSet", [EvDistribution(dist), EvNumber(float)]) => | ("toSampleSet", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist) Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist, ~env)
| ("toSampleSet", [EvDistribution(dist)]) => | ("toSampleSet", [EvDistribution(dist)]) =>
Helpers.toDistFn(ToSampleSet(MagicNumbers.Environment.defaultSampleCount), dist) Helpers.toDistFn(ToSampleSet(env.sampleCount), dist, ~env)
| ("fromSamples", [EvArray(inputArray)]) => { | ("fromSamples", [EvArray(inputArray)]) => {
let _wrapInputErrors = x => SampleSetDist.NonNumericInput(x) let _wrapInputErrors = x => SampleSetDist.NonNumericInput(x)
let parsedArray = Helpers.parseNumberArray(inputArray)->E.R2.errMap(_wrapInputErrors) let parsedArray = Helpers.parseNumberArray(inputArray)->E.R2.errMap(_wrapInputErrors)
switch parsedArray { switch parsedArray {
| Ok(array) => runGenericOperation(FromSamples(array)) | Ok(array) => DistributionOperation.run(FromSamples(array), ~env)
| Error(e) => GenDistError(SampleSetError(e)) | Error(e) => GenDistError(SampleSetError(e))
}->Some }->Some
} }
| ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist) | ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist, ~env)
| ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) => | ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Truncate(Some(float), None), dist) Helpers.toDistFn(Truncate(Some(float), None), dist, ~env)
| ("truncateRight", [EvDistribution(dist), EvNumber(float)]) => | ("truncateRight", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Truncate(None, Some(float)), dist) Helpers.toDistFn(Truncate(None, Some(float)), dist, ~env)
| ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) => | ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) =>
Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist) Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist, ~env)
| ("mx" | "mixture", args) => Helpers.mixture(args)->Some | ("mx" | "mixture", args) => Helpers.mixture(args, ~env)->Some
| ("log", [EvDistribution(a)]) => | ("log", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn( Helpers.twoDiststoDistFn(
Algebraic(AsDefault), Algebraic(AsDefault),
"log", "log",
a, a,
GenericDist.fromFloat(MagicNumbers.Math.e), GenericDist.fromFloat(MagicNumbers.Math.e),
~env,
)->Some )->Some
| ("log10", [EvDistribution(a)]) => | ("log10", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn(Algebraic(AsDefault), "log", a, GenericDist.fromFloat(10.0))->Some Helpers.twoDiststoDistFn(
Algebraic(AsDefault),
"log",
a,
GenericDist.fromFloat(10.0),
~env,
)->Some
| ("unaryMinus", [EvDistribution(a)]) => | ("unaryMinus", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn(Algebraic(AsDefault), "multiply", a, GenericDist.fromFloat(-1.0))->Some Helpers.twoDiststoDistFn(
Algebraic(AsDefault),
"multiply",
a,
GenericDist.fromFloat(-1.0),
~env,
)->Some
| (("add" | "multiply" | "subtract" | "divide" | "pow" | "log") as arithmetic, [_, _] as args) => | (("add" | "multiply" | "subtract" | "divide" | "pow" | "log") as arithmetic, [_, _] as args) =>
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd) Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd, ~env)
) )
| ( | (
("dotAdd" ("dotAdd"
@ -296,7 +316,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
[_, _] as args, [_, _] as args,
) => ) =>
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd) Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd, ~env)
) )
| ("dotExp", [EvDistribution(a)]) => | ("dotExp", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn( Helpers.twoDiststoDistFn(
@ -304,6 +324,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
"dotPow", "dotPow",
GenericDist.fromFloat(MagicNumbers.Math.e), GenericDist.fromFloat(MagicNumbers.Math.e),
a, a,
~env,
)->Some )->Some
| _ => None | _ => None
} }

View File

@ -1,4 +1,3 @@
let defaultEnv: DistributionOperation.env
let dispatch: ( let dispatch: (
ReducerInterface_ExpressionValue.functionCall, ReducerInterface_ExpressionValue.functionCall,
ReducerInterface_ExpressionValue.environment, ReducerInterface_ExpressionValue.environment,

View File

@ -77,7 +77,7 @@ let distributionErrorToString = DistributionTypes.Error.toString
type lambdaValue = ReducerInterface_ExpressionValue.lambdaValue type lambdaValue = ReducerInterface_ExpressionValue.lambdaValue
@genType @genType
let defaultSamplingEnv = ReducerInterface_GenericDistribution.defaultEnv let defaultSamplingEnv = DistributionOperation.defaultEnv
@genType @genType
type environment = ReducerInterface_ExpressionValue.environment type environment = ReducerInterface_ExpressionValue.environment

View File

@ -453,6 +453,44 @@ module PointwiseCombination = {
T.filterOkYs(newXs, newYs)->Ok T.filterOkYs(newXs, newYs)->Ok
} }
/* *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
If another traveler comes through with a similar idea, we hope this implementation will help them.
By "enrich" we mean to increase granularity.
*/
let enrichXyShape = (t: T.t): T.t => {
let defaultEnrichmentFactor = 10
let length = E.A.length(t.xs)
let points =
length < MagicNumbers.Environment.defaultXYPointLength
? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
: defaultEnrichmentFactor
let getInBetween = (x1: float, x2: float): array<float> => {
if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {
[x1]
} else {
let newPointsArray = Belt.Array.makeBy(points - 1, i => i)
// don't repeat the x2 point, it will be gotten in the next iteration.
let result = Js.Array.mapi((pos, i) =>
if i == 0 {
x1
} else {
let points' = Belt.Float.fromInt(points)
let pos' = Belt.Float.fromInt(pos)
x1 *. (points' -. pos') /. points' +. x2 *. pos' /. points'
}
, newPointsArray)
result
}
}
let newXsUnflattened = Js.Array.mapi(
(x, i) => i < length - 2 ? getInBetween(x, t.xs[i + 1]) : [x],
t.xs,
)
let newXs = Belt.Array.concatMany(newXsUnflattened)
let newYs = E.A.fmap(x => XtoY.linear(x, t), newXs)
{xs: newXs, ys: newYs}
}
// This function is used for klDivergence // This function is used for klDivergence
let combineAlongSupportOfSecondArgument: ( let combineAlongSupportOfSecondArgument: (
(float, float) => result<float, Operation.Error.t>, (float, float) => result<float, Operation.Error.t>,

View File

@ -3726,10 +3726,10 @@
"@testing-library/dom" "^8.5.0" "@testing-library/dom" "^8.5.0"
"@types/react-dom" "^18.0.0" "@types/react-dom" "^18.0.0"
"@testing-library/user-event@^14.1.1": "@testing-library/user-event@^14.2.0":
version "14.1.1" version "14.2.0"
resolved "https://registry.yarnpkg.com/@testing-library/user-event/-/user-event-14.1.1.tgz#e1ff6118896e4b22af31e5ea2f9da956adde23d8" resolved "https://registry.yarnpkg.com/@testing-library/user-event/-/user-event-14.2.0.tgz#8293560f8f80a00383d6c755ec3e0b918acb1683"
integrity sha512-XrjH/iEUqNl9lF2HX9YhPNV7Amntkcnpw0Bo1KkRzowNDcgSN9i0nm4Q8Oi5wupgdfPaJNMAWa61A+voD6Kmwg== integrity sha512-+hIlG4nJS6ivZrKnOP7OGsDu9Fxmryj9vCl8x0ZINtTJcCHs2zLsYif5GzuRiBF2ck5GZG2aQr7Msg+EHlnYVQ==
"@tootallnate/once@1": "@tootallnate/once@1":
version "1.1.2" version "1.1.2"
@ -4038,10 +4038,10 @@
"@types/node" "*" "@types/node" "*"
form-data "^3.0.0" form-data "^3.0.0"
"@types/node@*", "@types/node@^17.0.31", "@types/node@^17.0.5": "@types/node@*", "@types/node@^17.0.32", "@types/node@^17.0.5":
version "17.0.31" version "17.0.32"
resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.31.tgz#a5bb84ecfa27eec5e1c802c6bbf8139bdb163a5d" resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.32.tgz#51d59d7a90ef2d0ae961791e0900cad2393a0149"
integrity sha512-AR0x5HbXGqkEx9CadRH3EBYx/VkiUgZIhP4wvPn/+5KIsgpNoyFaRlVe0Zlx9gRtg8fA06a9tskE2MSN7TcG4Q== integrity sha512-eAIcfAvhf/BkHcf4pkLJ7ECpBAhh9kcxRBpip9cTiO+hf+aJrsxYxBeS6OXvOd9WqNAJmavXVpZvY1rBjNsXmw==
"@types/node@^14.0.10": "@types/node@^14.0.10":
version "14.18.16" version "14.18.16"