Merged with master
This commit is contained in:
commit
47d7ef49cf
|
@ -28,10 +28,10 @@
|
||||||
"@storybook/react": "^6.4.22",
|
"@storybook/react": "^6.4.22",
|
||||||
"@testing-library/jest-dom": "^5.16.4",
|
"@testing-library/jest-dom": "^5.16.4",
|
||||||
"@testing-library/react": "^13.2.0",
|
"@testing-library/react": "^13.2.0",
|
||||||
"@testing-library/user-event": "^14.1.1",
|
"@testing-library/user-event": "^14.2.0",
|
||||||
"@types/jest": "^27.5.0",
|
"@types/jest": "^27.5.0",
|
||||||
"@types/lodash": "^4.14.182",
|
"@types/lodash": "^4.14.182",
|
||||||
"@types/node": "^17.0.31",
|
"@types/node": "^17.0.32",
|
||||||
"@types/react": "^18.0.9",
|
"@types/react": "^18.0.9",
|
||||||
"@types/react-dom": "^18.0.2",
|
"@types/react-dom": "^18.0.2",
|
||||||
"@types/styled-components": "^5.1.24",
|
"@types/styled-components": "^5.1.24",
|
||||||
|
|
|
@ -76,7 +76,9 @@ export const FunctionChart: React.FC<FunctionChartProps> = ({
|
||||||
chartSettings.count
|
chartSettings.count
|
||||||
);
|
);
|
||||||
type point = { x: number; value: result<Distribution, string> };
|
type point = { x: number; value: result<Distribution, string> };
|
||||||
let valueData: point[] = data1.map((x) => {
|
let valueData: point[] = React.useMemo(
|
||||||
|
() =>
|
||||||
|
data1.map((x) => {
|
||||||
let result = runForeign(fn, [x], environment);
|
let result = runForeign(fn, [x], environment);
|
||||||
if (result.tag === "Ok") {
|
if (result.tag === "Ok") {
|
||||||
if (result.value.tag == "distribution") {
|
if (result.value.tag == "distribution") {
|
||||||
|
@ -97,7 +99,9 @@ export const FunctionChart: React.FC<FunctionChartProps> = ({
|
||||||
value: { tag: "Error", value: errorValueToString(result.value) },
|
value: { tag: "Error", value: errorValueToString(result.value) },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
});
|
}),
|
||||||
|
[environment, fn]
|
||||||
|
);
|
||||||
|
|
||||||
let initialPartition: [
|
let initialPartition: [
|
||||||
{ x: number; value: Distribution }[],
|
{ x: number; value: Distribution }[],
|
||||||
|
@ -141,10 +145,10 @@ export const FunctionChart: React.FC<FunctionChartProps> = ({
|
||||||
/>
|
/>
|
||||||
{showChart}
|
{showChart}
|
||||||
{_.entries(groupedErrors).map(([errorName, errorPoints]) => (
|
{_.entries(groupedErrors).map(([errorName, errorPoints]) => (
|
||||||
<ErrorBox heading={errorName}>
|
<ErrorBox key={errorName} heading={errorName}>
|
||||||
Values:{" "}
|
Values:{" "}
|
||||||
{errorPoints
|
{errorPoints
|
||||||
.map((r) => <NumberShower number={r.x} />)
|
.map((r, i) => <NumberShower key={i} number={r.x} />)
|
||||||
.reduce((a, b) => (
|
.reduce((a, b) => (
|
||||||
<>
|
<>
|
||||||
{a}, {b}
|
{a}, {b}
|
||||||
|
|
|
@ -148,8 +148,9 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
|
||||||
case "array":
|
case "array":
|
||||||
return (
|
return (
|
||||||
<VariableBox heading="Array" showTypes={showTypes}>
|
<VariableBox heading="Array" showTypes={showTypes}>
|
||||||
{expression.value.map((r) => (
|
{expression.value.map((r, i) => (
|
||||||
<SquiggleItem
|
<SquiggleItem
|
||||||
|
key={i}
|
||||||
expression={r}
|
expression={r}
|
||||||
width={width !== undefined ? width - 20 : width}
|
width={width !== undefined ? width - 20 : width}
|
||||||
height={50}
|
height={50}
|
||||||
|
@ -166,7 +167,7 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
|
||||||
return (
|
return (
|
||||||
<VariableBox heading="Record" showTypes={showTypes}>
|
<VariableBox heading="Record" showTypes={showTypes}>
|
||||||
{Object.entries(expression.value).map(([key, r]) => (
|
{Object.entries(expression.value).map(([key, r]) => (
|
||||||
<>
|
<div key={key}>
|
||||||
<RecordKeyHeader>{key}</RecordKeyHeader>
|
<RecordKeyHeader>{key}</RecordKeyHeader>
|
||||||
<SquiggleItem
|
<SquiggleItem
|
||||||
expression={r}
|
expression={r}
|
||||||
|
@ -178,14 +179,14 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
|
||||||
chartSettings={chartSettings}
|
chartSettings={chartSettings}
|
||||||
environment={environment}
|
environment={environment}
|
||||||
/>
|
/>
|
||||||
</>
|
</div>
|
||||||
))}
|
))}
|
||||||
</VariableBox>
|
</VariableBox>
|
||||||
);
|
);
|
||||||
case "arraystring":
|
case "arraystring":
|
||||||
return (
|
return (
|
||||||
<VariableBox heading="Array String" showTypes={showTypes}>
|
<VariableBox heading="Array String" showTypes={showTypes}>
|
||||||
{expression.value.map((r) => `"${r}"`)}
|
{expression.value.map((r) => `"${r}"`).join(", ")}
|
||||||
</VariableBox>
|
</VariableBox>
|
||||||
);
|
);
|
||||||
case "lambda":
|
case "lambda":
|
||||||
|
|
|
@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
|
||||||
)
|
)
|
||||||
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
|
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
|
||||||
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
||||||
|
let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0}))
|
||||||
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
|
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
|
||||||
|
|
||||||
exception KlFailed
|
exception KlFailed
|
||||||
|
|
|
@ -3,9 +3,15 @@ open Expect
|
||||||
open TestHelpers
|
open TestHelpers
|
||||||
open GenericDist_Fixtures
|
open GenericDist_Fixtures
|
||||||
|
|
||||||
|
// integral from low to high of 1 / (high - low) log(normal(mean, stdev)(x) / (1 / (high - low))) dx
|
||||||
|
let klNormalUniform = (mean, stdev, low, high): float =>
|
||||||
|
-.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. stdev ** 2.0)) +.
|
||||||
|
1.0 /.
|
||||||
|
stdev ** 2.0 *.
|
||||||
|
(mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0)
|
||||||
|
|
||||||
describe("klDivergence: continuous -> continuous -> float", () => {
|
describe("klDivergence: continuous -> continuous -> float", () => {
|
||||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||||
exception KlFailed
|
|
||||||
|
|
||||||
let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
|
let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
|
||||||
test("of two uniforms is equal to the analytic expression", () => {
|
test("of two uniforms is equal to the analytic expression", () => {
|
||||||
|
@ -59,6 +65,20 @@ describe("klDivergence: continuous -> continuous -> float", () => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test("of a normal and a uniform is equal to the formula", () => {
|
||||||
|
let prediction = normalDist10
|
||||||
|
let answer = uniformDist
|
||||||
|
let kl = klDivergence(prediction, answer)
|
||||||
|
let analyticalKl = klNormalUniform(10.0, 2.0, 9.0, 10.0)
|
||||||
|
switch kl {
|
||||||
|
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1)
|
||||||
|
| Error(err) => {
|
||||||
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
|
raise(KlFailed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe("klDivergence: discrete -> discrete -> float", () => {
|
describe("klDivergence: discrete -> discrete -> float", () => {
|
||||||
|
@ -96,6 +116,64 @@ describe("klDivergence: discrete -> discrete -> float", () => {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe("klDivergence: mixed -> mixed -> float", () => {
|
||||||
|
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||||
|
let mixture' = a => DistributionTypes.DistributionOperation.Mixture(a)
|
||||||
|
let mixture = a => {
|
||||||
|
let dist' = a->mixture'->run
|
||||||
|
switch dist' {
|
||||||
|
| Dist(dist) => dist
|
||||||
|
| _ => raise(MixtureFailed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let a = [(point1, 1.0), (uniformDist, 1.0)]->mixture
|
||||||
|
let b = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture
|
||||||
|
let c = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture
|
||||||
|
let d =
|
||||||
|
[(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]->mixture
|
||||||
|
|
||||||
|
test("finite klDivergence produces correct answer", () => {
|
||||||
|
let prediction = b
|
||||||
|
let answer = a
|
||||||
|
let kl = klDivergence(prediction, answer)
|
||||||
|
// high = 10; low = 9; mean = 10; stdev = 2
|
||||||
|
let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0
|
||||||
|
let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0)
|
||||||
|
switch kl {
|
||||||
|
| Ok(kl') =>
|
||||||
|
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
|
||||||
|
| Error(err) =>
|
||||||
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
|
raise(KlFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
test("returns infinity when infinite", () => {
|
||||||
|
let prediction = a
|
||||||
|
let answer = b
|
||||||
|
let kl = klDivergence(prediction, answer)
|
||||||
|
switch kl {
|
||||||
|
| Ok(kl') => kl'->expect->toEqual(infinity)
|
||||||
|
| Error(err) =>
|
||||||
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
|
raise(KlFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
test("finite klDivergence produces correct answer", () => {
|
||||||
|
let prediction = d
|
||||||
|
let answer = c
|
||||||
|
let kl = klDivergence(prediction, answer)
|
||||||
|
let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array
|
||||||
|
let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0)
|
||||||
|
switch kl {
|
||||||
|
| Ok(kl') =>
|
||||||
|
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
|
||||||
|
| Error(err) =>
|
||||||
|
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||||
|
raise(KlFailed)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
describe("combineAlongSupportOfSecondArgument0", () => {
|
describe("combineAlongSupportOfSecondArgument0", () => {
|
||||||
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
|
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
|
||||||
test("test on two uniforms", _ => {
|
test("test on two uniforms", _ => {
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "@quri/squiggle-lang",
|
"name": "@quri/squiggle-lang",
|
||||||
"version": "0.2.8",
|
"version": "0.2.9",
|
||||||
"homepage": "https://squiggle-language.com",
|
"homepage": "https://squiggle-language.com",
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
|
|
@ -302,10 +302,9 @@ module T = Dist({
|
||||||
}
|
}
|
||||||
|
|
||||||
let klDivergence = (prediction: t, answer: t) => {
|
let klDivergence = (prediction: t, answer: t) => {
|
||||||
Error(Operation.NotYetImplemented)
|
let klDiscretePart = Discrete.T.klDivergence(prediction.discrete, answer.discrete)
|
||||||
// combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
|
let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
|
||||||
// integralEndY,
|
E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
|
||||||
// )
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -200,6 +200,7 @@ module T = Dist({
|
||||||
switch (t1, t2) {
|
switch (t1, t2) {
|
||||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
||||||
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
||||||
|
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
|
||||||
| _ => Error(NotYetImplemented)
|
| _ => Error(NotYetImplemented)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,13 +1,6 @@
|
||||||
module ExpressionValue = ReducerInterface_ExpressionValue
|
module ExpressionValue = ReducerInterface_ExpressionValue
|
||||||
type expressionValue = ReducerInterface_ExpressionValue.expressionValue
|
type expressionValue = ReducerInterface_ExpressionValue.expressionValue
|
||||||
|
|
||||||
let defaultEnv: DistributionOperation.env = {
|
|
||||||
sampleCount: MagicNumbers.Environment.defaultSampleCount,
|
|
||||||
xyPointLength: MagicNumbers.Environment.defaultXYPointLength,
|
|
||||||
}
|
|
||||||
|
|
||||||
let runGenericOperation = DistributionOperation.run(~env=defaultEnv)
|
|
||||||
|
|
||||||
module Helpers = {
|
module Helpers = {
|
||||||
let arithmeticMap = r =>
|
let arithmeticMap = r =>
|
||||||
switch r {
|
switch r {
|
||||||
|
@ -39,37 +32,44 @@ module Helpers = {
|
||||||
let toFloatFn = (
|
let toFloatFn = (
|
||||||
fnCall: DistributionTypes.DistributionOperation.toFloat,
|
fnCall: DistributionTypes.DistributionOperation.toFloat,
|
||||||
dist: DistributionTypes.genericDist,
|
dist: DistributionTypes.genericDist,
|
||||||
|
~env: DistributionOperation.env,
|
||||||
) => {
|
) => {
|
||||||
FromDist(DistributionTypes.DistributionOperation.ToFloat(fnCall), dist)
|
FromDist(DistributionTypes.DistributionOperation.ToFloat(fnCall), dist)
|
||||||
->runGenericOperation
|
->DistributionOperation.run(~env)
|
||||||
->Some
|
->Some
|
||||||
}
|
}
|
||||||
|
|
||||||
let toStringFn = (
|
let toStringFn = (
|
||||||
fnCall: DistributionTypes.DistributionOperation.toString,
|
fnCall: DistributionTypes.DistributionOperation.toString,
|
||||||
dist: DistributionTypes.genericDist,
|
dist: DistributionTypes.genericDist,
|
||||||
|
~env: DistributionOperation.env,
|
||||||
) => {
|
) => {
|
||||||
FromDist(DistributionTypes.DistributionOperation.ToString(fnCall), dist)
|
FromDist(DistributionTypes.DistributionOperation.ToString(fnCall), dist)
|
||||||
->runGenericOperation
|
->DistributionOperation.run(~env)
|
||||||
->Some
|
->Some
|
||||||
}
|
}
|
||||||
|
|
||||||
let toBoolFn = (
|
let toBoolFn = (
|
||||||
fnCall: DistributionTypes.DistributionOperation.toBool,
|
fnCall: DistributionTypes.DistributionOperation.toBool,
|
||||||
dist: DistributionTypes.genericDist,
|
dist: DistributionTypes.genericDist,
|
||||||
|
~env: DistributionOperation.env,
|
||||||
) => {
|
) => {
|
||||||
FromDist(DistributionTypes.DistributionOperation.ToBool(fnCall), dist)
|
FromDist(DistributionTypes.DistributionOperation.ToBool(fnCall), dist)
|
||||||
->runGenericOperation
|
->DistributionOperation.run(~env)
|
||||||
->Some
|
->Some
|
||||||
}
|
}
|
||||||
|
|
||||||
let toDistFn = (fnCall: DistributionTypes.DistributionOperation.toDist, dist) => {
|
let toDistFn = (
|
||||||
|
fnCall: DistributionTypes.DistributionOperation.toDist,
|
||||||
|
dist,
|
||||||
|
~env: DistributionOperation.env,
|
||||||
|
) => {
|
||||||
FromDist(DistributionTypes.DistributionOperation.ToDist(fnCall), dist)
|
FromDist(DistributionTypes.DistributionOperation.ToDist(fnCall), dist)
|
||||||
->runGenericOperation
|
->DistributionOperation.run(~env)
|
||||||
->Some
|
->Some
|
||||||
}
|
}
|
||||||
|
|
||||||
let twoDiststoDistFn = (direction, arithmetic, dist1, dist2) => {
|
let twoDiststoDistFn = (direction, arithmetic, dist1, dist2, ~env: DistributionOperation.env) => {
|
||||||
FromDist(
|
FromDist(
|
||||||
DistributionTypes.DistributionOperation.ToDistCombination(
|
DistributionTypes.DistributionOperation.ToDistCombination(
|
||||||
direction,
|
direction,
|
||||||
|
@ -77,7 +77,7 @@ module Helpers = {
|
||||||
#Dist(dist2),
|
#Dist(dist2),
|
||||||
),
|
),
|
||||||
dist1,
|
dist1,
|
||||||
)->runGenericOperation
|
)->DistributionOperation.run(~env)
|
||||||
}
|
}
|
||||||
|
|
||||||
let parseNumber = (args: expressionValue): Belt.Result.t<float, string> =>
|
let parseNumber = (args: expressionValue): Belt.Result.t<float, string> =>
|
||||||
|
@ -104,33 +104,38 @@ module Helpers = {
|
||||||
let mixtureWithGivenWeights = (
|
let mixtureWithGivenWeights = (
|
||||||
distributions: array<DistributionTypes.genericDist>,
|
distributions: array<DistributionTypes.genericDist>,
|
||||||
weights: array<float>,
|
weights: array<float>,
|
||||||
|
~env: DistributionOperation.env,
|
||||||
): DistributionOperation.outputType =>
|
): DistributionOperation.outputType =>
|
||||||
E.A.length(distributions) == E.A.length(weights)
|
E.A.length(distributions) == E.A.length(weights)
|
||||||
? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation
|
? Mixture(Belt.Array.zip(distributions, weights))->DistributionOperation.run(~env)
|
||||||
: GenDistError(
|
: GenDistError(
|
||||||
ArgumentError("Error, mixture call has different number of distributions and weights"),
|
ArgumentError("Error, mixture call has different number of distributions and weights"),
|
||||||
)
|
)
|
||||||
|
|
||||||
let mixtureWithDefaultWeights = (
|
let mixtureWithDefaultWeights = (
|
||||||
distributions: array<DistributionTypes.genericDist>,
|
distributions: array<DistributionTypes.genericDist>,
|
||||||
|
~env: DistributionOperation.env,
|
||||||
): DistributionOperation.outputType => {
|
): DistributionOperation.outputType => {
|
||||||
let length = E.A.length(distributions)
|
let length = E.A.length(distributions)
|
||||||
let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length))
|
let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length))
|
||||||
mixtureWithGivenWeights(distributions, weights)
|
mixtureWithGivenWeights(distributions, weights, ~env)
|
||||||
}
|
}
|
||||||
|
|
||||||
let mixture = (args: array<expressionValue>): DistributionOperation.outputType => {
|
let mixture = (
|
||||||
|
args: array<expressionValue>,
|
||||||
|
~env: DistributionOperation.env,
|
||||||
|
): DistributionOperation.outputType => {
|
||||||
let error = (err: string): DistributionOperation.outputType =>
|
let error = (err: string): DistributionOperation.outputType =>
|
||||||
err->DistributionTypes.ArgumentError->GenDistError
|
err->DistributionTypes.ArgumentError->GenDistError
|
||||||
switch args {
|
switch args {
|
||||||
| [EvArray(distributions)] =>
|
| [EvArray(distributions)] =>
|
||||||
switch parseDistributionArray(distributions) {
|
switch parseDistributionArray(distributions) {
|
||||||
| Ok(distrs) => mixtureWithDefaultWeights(distrs)
|
| Ok(distrs) => mixtureWithDefaultWeights(distrs, ~env)
|
||||||
| Error(err) => error(err)
|
| Error(err) => error(err)
|
||||||
}
|
}
|
||||||
| [EvArray(distributions), EvArray(weights)] =>
|
| [EvArray(distributions), EvArray(weights)] =>
|
||||||
switch (parseDistributionArray(distributions), parseNumberArray(weights)) {
|
switch (parseDistributionArray(distributions), parseNumberArray(weights)) {
|
||||||
| (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts)
|
| (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts, ~env)
|
||||||
| (Error(err), Ok(_)) => error(err)
|
| (Error(err), Ok(_)) => error(err)
|
||||||
| (Ok(_), Error(err)) => error(err)
|
| (Ok(_), Error(err)) => error(err)
|
||||||
| (Error(err1), Error(err2)) => error(`${err1}|${err2}`)
|
| (Error(err1), Error(err2)) => error(`${err1}|${err2}`)
|
||||||
|
@ -143,14 +148,14 @@ module Helpers = {
|
||||||
Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1),
|
Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1),
|
||||||
)
|
)
|
||||||
switch E.R.merge(distributions, weights) {
|
switch E.R.merge(distributions, weights) {
|
||||||
| Ok(d, w) => mixtureWithGivenWeights(d, w)
|
| Ok(d, w) => mixtureWithGivenWeights(d, w, ~env)
|
||||||
| Error(err) => error(err)
|
| Error(err) => error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
| Some(EvNumber(_))
|
| Some(EvNumber(_))
|
||||||
| Some(EvDistribution(_)) =>
|
| Some(EvDistribution(_)) =>
|
||||||
switch parseDistributionArray(args) {
|
switch parseDistributionArray(args) {
|
||||||
| Ok(distributions) => mixtureWithDefaultWeights(distributions)
|
| Ok(distributions) => mixtureWithDefaultWeights(distributions, ~env)
|
||||||
| Error(err) => error(err)
|
| Error(err) => error(err)
|
||||||
}
|
}
|
||||||
| _ => error("Last argument of mx must be array or distribution")
|
| _ => error("Last argument of mx must be array or distribution")
|
||||||
|
@ -193,9 +198,10 @@ module SymbolicConstructors = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment): option<
|
let dispatchToGenericOutput = (
|
||||||
DistributionOperation.outputType,
|
call: ExpressionValue.functionCall,
|
||||||
> => {
|
env: DistributionOperation.env,
|
||||||
|
): option<DistributionOperation.outputType> => {
|
||||||
let (fnName, args) = call
|
let (fnName, args) = call
|
||||||
switch (fnName, args) {
|
switch (fnName, args) {
|
||||||
| ("exponential" as fnName, [EvNumber(f)]) =>
|
| ("exponential" as fnName, [EvNumber(f)]) =>
|
||||||
|
@ -215,16 +221,16 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
|
||||||
SymbolicConstructors.threeFloat(fnName)
|
SymbolicConstructors.threeFloat(fnName)
|
||||||
->E.R.bind(r => r(f1, f2, f3))
|
->E.R.bind(r => r(f1, f2, f3))
|
||||||
->SymbolicConstructors.symbolicResultToOutput
|
->SymbolicConstructors.symbolicResultToOutput
|
||||||
| ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist)
|
| ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist, ~env)
|
||||||
| ("sampleN", [EvDistribution(dist), EvNumber(n)]) => Some(
|
| ("sampleN", [EvDistribution(dist), EvNumber(n)]) => Some(
|
||||||
FloatArray(GenericDist.sampleN(dist, Belt.Int.fromFloat(n))),
|
FloatArray(GenericDist.sampleN(dist, Belt.Int.fromFloat(n))),
|
||||||
)
|
)
|
||||||
| ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist)
|
| ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist, ~env)
|
||||||
| ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist)
|
| ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist, ~env)
|
||||||
| ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist)
|
| ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist, ~env)
|
||||||
| ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist)
|
| ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist, ~env)
|
||||||
| ("toSparkline", [EvDistribution(dist), EvNumber(n)]) =>
|
| ("toSparkline", [EvDistribution(dist), EvNumber(n)]) =>
|
||||||
Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist)
|
Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist, ~env)
|
||||||
| ("exp", [EvDistribution(a)]) =>
|
| ("exp", [EvDistribution(a)]) =>
|
||||||
// https://mathjs.org/docs/reference/functions/exp.html
|
// https://mathjs.org/docs/reference/functions/exp.html
|
||||||
Helpers.twoDiststoDistFn(
|
Helpers.twoDiststoDistFn(
|
||||||
|
@ -232,60 +238,74 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
|
||||||
"pow",
|
"pow",
|
||||||
GenericDist.fromFloat(MagicNumbers.Math.e),
|
GenericDist.fromFloat(MagicNumbers.Math.e),
|
||||||
a,
|
a,
|
||||||
|
~env,
|
||||||
)->Some
|
)->Some
|
||||||
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist)
|
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env)
|
||||||
| ("klDivergence", [EvDistribution(a), EvDistribution(b)]) =>
|
| ("klDivergence", [EvDistribution(a), EvDistribution(b)]) =>
|
||||||
Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a)))
|
Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env))
|
||||||
| ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist)
|
| ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env)
|
||||||
| ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist)
|
| ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env)
|
||||||
| ("scaleLog", [EvDistribution(dist)]) =>
|
| ("scaleLog", [EvDistribution(dist)]) =>
|
||||||
Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist)
|
Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist, ~env)
|
||||||
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist)
|
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist, ~env)
|
||||||
| ("scaleLog", [EvDistribution(dist), EvNumber(float)]) =>
|
| ("scaleLog", [EvDistribution(dist), EvNumber(float)]) =>
|
||||||
Helpers.toDistFn(Scale(#Logarithm, float), dist)
|
Helpers.toDistFn(Scale(#Logarithm, float), dist, ~env)
|
||||||
| ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) =>
|
| ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) =>
|
||||||
Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist)
|
Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist, ~env)
|
||||||
| ("scalePow", [EvDistribution(dist), EvNumber(float)]) =>
|
| ("scalePow", [EvDistribution(dist), EvNumber(float)]) =>
|
||||||
Helpers.toDistFn(Scale(#Power, float), dist)
|
Helpers.toDistFn(Scale(#Power, float), dist, ~env)
|
||||||
| ("scaleExp", [EvDistribution(dist)]) =>
|
| ("scaleExp", [EvDistribution(dist)]) =>
|
||||||
Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist)
|
Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist, ~env)
|
||||||
| ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist)
|
| ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist, ~env)
|
||||||
| ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist)
|
| ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist, ~env)
|
||||||
| ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist)
|
| ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist, ~env)
|
||||||
| ("toSampleSet", [EvDistribution(dist), EvNumber(float)]) =>
|
| ("toSampleSet", [EvDistribution(dist), EvNumber(float)]) =>
|
||||||
Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist)
|
Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist, ~env)
|
||||||
| ("toSampleSet", [EvDistribution(dist)]) =>
|
| ("toSampleSet", [EvDistribution(dist)]) =>
|
||||||
Helpers.toDistFn(ToSampleSet(MagicNumbers.Environment.defaultSampleCount), dist)
|
Helpers.toDistFn(ToSampleSet(env.sampleCount), dist, ~env)
|
||||||
| ("fromSamples", [EvArray(inputArray)]) => {
|
| ("fromSamples", [EvArray(inputArray)]) => {
|
||||||
let _wrapInputErrors = x => SampleSetDist.NonNumericInput(x)
|
let _wrapInputErrors = x => SampleSetDist.NonNumericInput(x)
|
||||||
let parsedArray = Helpers.parseNumberArray(inputArray)->E.R2.errMap(_wrapInputErrors)
|
let parsedArray = Helpers.parseNumberArray(inputArray)->E.R2.errMap(_wrapInputErrors)
|
||||||
switch parsedArray {
|
switch parsedArray {
|
||||||
| Ok(array) => runGenericOperation(FromSamples(array))
|
| Ok(array) => DistributionOperation.run(FromSamples(array), ~env)
|
||||||
| Error(e) => GenDistError(SampleSetError(e))
|
| Error(e) => GenDistError(SampleSetError(e))
|
||||||
}->Some
|
}->Some
|
||||||
}
|
}
|
||||||
| ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist)
|
| ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist, ~env)
|
||||||
| ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) =>
|
| ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) =>
|
||||||
Helpers.toDistFn(Truncate(Some(float), None), dist)
|
Helpers.toDistFn(Truncate(Some(float), None), dist, ~env)
|
||||||
| ("truncateRight", [EvDistribution(dist), EvNumber(float)]) =>
|
| ("truncateRight", [EvDistribution(dist), EvNumber(float)]) =>
|
||||||
Helpers.toDistFn(Truncate(None, Some(float)), dist)
|
Helpers.toDistFn(Truncate(None, Some(float)), dist, ~env)
|
||||||
| ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) =>
|
| ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) =>
|
||||||
Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist)
|
Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist, ~env)
|
||||||
| ("mx" | "mixture", args) => Helpers.mixture(args)->Some
|
| ("mx" | "mixture", args) => Helpers.mixture(args, ~env)->Some
|
||||||
| ("log", [EvDistribution(a)]) =>
|
| ("log", [EvDistribution(a)]) =>
|
||||||
Helpers.twoDiststoDistFn(
|
Helpers.twoDiststoDistFn(
|
||||||
Algebraic(AsDefault),
|
Algebraic(AsDefault),
|
||||||
"log",
|
"log",
|
||||||
a,
|
a,
|
||||||
GenericDist.fromFloat(MagicNumbers.Math.e),
|
GenericDist.fromFloat(MagicNumbers.Math.e),
|
||||||
|
~env,
|
||||||
)->Some
|
)->Some
|
||||||
| ("log10", [EvDistribution(a)]) =>
|
| ("log10", [EvDistribution(a)]) =>
|
||||||
Helpers.twoDiststoDistFn(Algebraic(AsDefault), "log", a, GenericDist.fromFloat(10.0))->Some
|
Helpers.twoDiststoDistFn(
|
||||||
|
Algebraic(AsDefault),
|
||||||
|
"log",
|
||||||
|
a,
|
||||||
|
GenericDist.fromFloat(10.0),
|
||||||
|
~env,
|
||||||
|
)->Some
|
||||||
| ("unaryMinus", [EvDistribution(a)]) =>
|
| ("unaryMinus", [EvDistribution(a)]) =>
|
||||||
Helpers.twoDiststoDistFn(Algebraic(AsDefault), "multiply", a, GenericDist.fromFloat(-1.0))->Some
|
Helpers.twoDiststoDistFn(
|
||||||
|
Algebraic(AsDefault),
|
||||||
|
"multiply",
|
||||||
|
a,
|
||||||
|
GenericDist.fromFloat(-1.0),
|
||||||
|
~env,
|
||||||
|
)->Some
|
||||||
| (("add" | "multiply" | "subtract" | "divide" | "pow" | "log") as arithmetic, [_, _] as args) =>
|
| (("add" | "multiply" | "subtract" | "divide" | "pow" | "log") as arithmetic, [_, _] as args) =>
|
||||||
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
|
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
|
||||||
Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd)
|
Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd, ~env)
|
||||||
)
|
)
|
||||||
| (
|
| (
|
||||||
("dotAdd"
|
("dotAdd"
|
||||||
|
@ -296,7 +316,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
|
||||||
[_, _] as args,
|
[_, _] as args,
|
||||||
) =>
|
) =>
|
||||||
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
|
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
|
||||||
Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd)
|
Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd, ~env)
|
||||||
)
|
)
|
||||||
| ("dotExp", [EvDistribution(a)]) =>
|
| ("dotExp", [EvDistribution(a)]) =>
|
||||||
Helpers.twoDiststoDistFn(
|
Helpers.twoDiststoDistFn(
|
||||||
|
@ -304,6 +324,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
|
||||||
"dotPow",
|
"dotPow",
|
||||||
GenericDist.fromFloat(MagicNumbers.Math.e),
|
GenericDist.fromFloat(MagicNumbers.Math.e),
|
||||||
a,
|
a,
|
||||||
|
~env,
|
||||||
)->Some
|
)->Some
|
||||||
| _ => None
|
| _ => None
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
let defaultEnv: DistributionOperation.env
|
|
||||||
let dispatch: (
|
let dispatch: (
|
||||||
ReducerInterface_ExpressionValue.functionCall,
|
ReducerInterface_ExpressionValue.functionCall,
|
||||||
ReducerInterface_ExpressionValue.environment,
|
ReducerInterface_ExpressionValue.environment,
|
||||||
|
|
|
@ -77,7 +77,7 @@ let distributionErrorToString = DistributionTypes.Error.toString
|
||||||
type lambdaValue = ReducerInterface_ExpressionValue.lambdaValue
|
type lambdaValue = ReducerInterface_ExpressionValue.lambdaValue
|
||||||
|
|
||||||
@genType
|
@genType
|
||||||
let defaultSamplingEnv = ReducerInterface_GenericDistribution.defaultEnv
|
let defaultSamplingEnv = DistributionOperation.defaultEnv
|
||||||
|
|
||||||
@genType
|
@genType
|
||||||
type environment = ReducerInterface_ExpressionValue.environment
|
type environment = ReducerInterface_ExpressionValue.environment
|
||||||
|
|
|
@ -453,6 +453,44 @@ module PointwiseCombination = {
|
||||||
T.filterOkYs(newXs, newYs)->Ok
|
T.filterOkYs(newXs, newYs)->Ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
|
||||||
|
If another traveler comes through with a similar idea, we hope this implementation will help them.
|
||||||
|
By "enrich" we mean to increase granularity.
|
||||||
|
*/
|
||||||
|
let enrichXyShape = (t: T.t): T.t => {
|
||||||
|
let defaultEnrichmentFactor = 10
|
||||||
|
let length = E.A.length(t.xs)
|
||||||
|
let points =
|
||||||
|
length < MagicNumbers.Environment.defaultXYPointLength
|
||||||
|
? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
|
||||||
|
: defaultEnrichmentFactor
|
||||||
|
|
||||||
|
let getInBetween = (x1: float, x2: float): array<float> => {
|
||||||
|
if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {
|
||||||
|
[x1]
|
||||||
|
} else {
|
||||||
|
let newPointsArray = Belt.Array.makeBy(points - 1, i => i)
|
||||||
|
// don't repeat the x2 point, it will be gotten in the next iteration.
|
||||||
|
let result = Js.Array.mapi((pos, i) =>
|
||||||
|
if i == 0 {
|
||||||
|
x1
|
||||||
|
} else {
|
||||||
|
let points' = Belt.Float.fromInt(points)
|
||||||
|
let pos' = Belt.Float.fromInt(pos)
|
||||||
|
x1 *. (points' -. pos') /. points' +. x2 *. pos' /. points'
|
||||||
|
}
|
||||||
|
, newPointsArray)
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let newXsUnflattened = Js.Array.mapi(
|
||||||
|
(x, i) => i < length - 2 ? getInBetween(x, t.xs[i + 1]) : [x],
|
||||||
|
t.xs,
|
||||||
|
)
|
||||||
|
let newXs = Belt.Array.concatMany(newXsUnflattened)
|
||||||
|
let newYs = E.A.fmap(x => XtoY.linear(x, t), newXs)
|
||||||
|
{xs: newXs, ys: newYs}
|
||||||
|
}
|
||||||
// This function is used for klDivergence
|
// This function is used for klDivergence
|
||||||
let combineAlongSupportOfSecondArgument: (
|
let combineAlongSupportOfSecondArgument: (
|
||||||
(float, float) => result<float, Operation.Error.t>,
|
(float, float) => result<float, Operation.Error.t>,
|
||||||
|
|
16
yarn.lock
16
yarn.lock
|
@ -3726,10 +3726,10 @@
|
||||||
"@testing-library/dom" "^8.5.0"
|
"@testing-library/dom" "^8.5.0"
|
||||||
"@types/react-dom" "^18.0.0"
|
"@types/react-dom" "^18.0.0"
|
||||||
|
|
||||||
"@testing-library/user-event@^14.1.1":
|
"@testing-library/user-event@^14.2.0":
|
||||||
version "14.1.1"
|
version "14.2.0"
|
||||||
resolved "https://registry.yarnpkg.com/@testing-library/user-event/-/user-event-14.1.1.tgz#e1ff6118896e4b22af31e5ea2f9da956adde23d8"
|
resolved "https://registry.yarnpkg.com/@testing-library/user-event/-/user-event-14.2.0.tgz#8293560f8f80a00383d6c755ec3e0b918acb1683"
|
||||||
integrity sha512-XrjH/iEUqNl9lF2HX9YhPNV7Amntkcnpw0Bo1KkRzowNDcgSN9i0nm4Q8Oi5wupgdfPaJNMAWa61A+voD6Kmwg==
|
integrity sha512-+hIlG4nJS6ivZrKnOP7OGsDu9Fxmryj9vCl8x0ZINtTJcCHs2zLsYif5GzuRiBF2ck5GZG2aQr7Msg+EHlnYVQ==
|
||||||
|
|
||||||
"@tootallnate/once@1":
|
"@tootallnate/once@1":
|
||||||
version "1.1.2"
|
version "1.1.2"
|
||||||
|
@ -4038,10 +4038,10 @@
|
||||||
"@types/node" "*"
|
"@types/node" "*"
|
||||||
form-data "^3.0.0"
|
form-data "^3.0.0"
|
||||||
|
|
||||||
"@types/node@*", "@types/node@^17.0.31", "@types/node@^17.0.5":
|
"@types/node@*", "@types/node@^17.0.32", "@types/node@^17.0.5":
|
||||||
version "17.0.31"
|
version "17.0.32"
|
||||||
resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.31.tgz#a5bb84ecfa27eec5e1c802c6bbf8139bdb163a5d"
|
resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.32.tgz#51d59d7a90ef2d0ae961791e0900cad2393a0149"
|
||||||
integrity sha512-AR0x5HbXGqkEx9CadRH3EBYx/VkiUgZIhP4wvPn/+5KIsgpNoyFaRlVe0Zlx9gRtg8fA06a9tskE2MSN7TcG4Q==
|
integrity sha512-eAIcfAvhf/BkHcf4pkLJ7ECpBAhh9kcxRBpip9cTiO+hf+aJrsxYxBeS6OXvOd9WqNAJmavXVpZvY1rBjNsXmw==
|
||||||
|
|
||||||
"@types/node@^14.0.10":
|
"@types/node@^14.0.10":
|
||||||
version "14.18.16"
|
version "14.18.16"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user