Merged with master
This commit is contained in:
commit
47d7ef49cf
|
@ -28,10 +28,10 @@
|
|||
"@storybook/react": "^6.4.22",
|
||||
"@testing-library/jest-dom": "^5.16.4",
|
||||
"@testing-library/react": "^13.2.0",
|
||||
"@testing-library/user-event": "^14.1.1",
|
||||
"@testing-library/user-event": "^14.2.0",
|
||||
"@types/jest": "^27.5.0",
|
||||
"@types/lodash": "^4.14.182",
|
||||
"@types/node": "^17.0.31",
|
||||
"@types/node": "^17.0.32",
|
||||
"@types/react": "^18.0.9",
|
||||
"@types/react-dom": "^18.0.2",
|
||||
"@types/styled-components": "^5.1.24",
|
||||
|
|
|
@ -76,28 +76,32 @@ export const FunctionChart: React.FC<FunctionChartProps> = ({
|
|||
chartSettings.count
|
||||
);
|
||||
type point = { x: number; value: result<Distribution, string> };
|
||||
let valueData: point[] = data1.map((x) => {
|
||||
let result = runForeign(fn, [x], environment);
|
||||
if (result.tag === "Ok") {
|
||||
if (result.value.tag == "distribution") {
|
||||
return { x, value: { tag: "Ok", value: result.value.value } };
|
||||
} else {
|
||||
return {
|
||||
x,
|
||||
value: {
|
||||
tag: "Error",
|
||||
value:
|
||||
"Cannot currently render functions that don't return distributions",
|
||||
},
|
||||
};
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
x,
|
||||
value: { tag: "Error", value: errorValueToString(result.value) },
|
||||
};
|
||||
}
|
||||
});
|
||||
let valueData: point[] = React.useMemo(
|
||||
() =>
|
||||
data1.map((x) => {
|
||||
let result = runForeign(fn, [x], environment);
|
||||
if (result.tag === "Ok") {
|
||||
if (result.value.tag == "distribution") {
|
||||
return { x, value: { tag: "Ok", value: result.value.value } };
|
||||
} else {
|
||||
return {
|
||||
x,
|
||||
value: {
|
||||
tag: "Error",
|
||||
value:
|
||||
"Cannot currently render functions that don't return distributions",
|
||||
},
|
||||
};
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
x,
|
||||
value: { tag: "Error", value: errorValueToString(result.value) },
|
||||
};
|
||||
}
|
||||
}),
|
||||
[environment, fn]
|
||||
);
|
||||
|
||||
let initialPartition: [
|
||||
{ x: number; value: Distribution }[],
|
||||
|
@ -141,10 +145,10 @@ export const FunctionChart: React.FC<FunctionChartProps> = ({
|
|||
/>
|
||||
{showChart}
|
||||
{_.entries(groupedErrors).map(([errorName, errorPoints]) => (
|
||||
<ErrorBox heading={errorName}>
|
||||
<ErrorBox key={errorName} heading={errorName}>
|
||||
Values:{" "}
|
||||
{errorPoints
|
||||
.map((r) => <NumberShower number={r.x} />)
|
||||
.map((r, i) => <NumberShower key={i} number={r.x} />)
|
||||
.reduce((a, b) => (
|
||||
<>
|
||||
{a}, {b}
|
||||
|
|
|
@ -148,8 +148,9 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
|
|||
case "array":
|
||||
return (
|
||||
<VariableBox heading="Array" showTypes={showTypes}>
|
||||
{expression.value.map((r) => (
|
||||
{expression.value.map((r, i) => (
|
||||
<SquiggleItem
|
||||
key={i}
|
||||
expression={r}
|
||||
width={width !== undefined ? width - 20 : width}
|
||||
height={50}
|
||||
|
@ -166,7 +167,7 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
|
|||
return (
|
||||
<VariableBox heading="Record" showTypes={showTypes}>
|
||||
{Object.entries(expression.value).map(([key, r]) => (
|
||||
<>
|
||||
<div key={key}>
|
||||
<RecordKeyHeader>{key}</RecordKeyHeader>
|
||||
<SquiggleItem
|
||||
expression={r}
|
||||
|
@ -178,14 +179,14 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
|
|||
chartSettings={chartSettings}
|
||||
environment={environment}
|
||||
/>
|
||||
</>
|
||||
</div>
|
||||
))}
|
||||
</VariableBox>
|
||||
);
|
||||
case "arraystring":
|
||||
return (
|
||||
<VariableBox heading="Array String" showTypes={showTypes}>
|
||||
{expression.value.map((r) => `"${r}"`)}
|
||||
{expression.value.map((r) => `"${r}"`).join(", ")}
|
||||
</VariableBox>
|
||||
);
|
||||
case "lambda":
|
||||
|
|
|
@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
|
|||
)
|
||||
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
|
||||
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
|
||||
let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0}))
|
||||
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))
|
||||
|
||||
exception KlFailed
|
||||
|
|
|
@ -3,9 +3,15 @@ open Expect
|
|||
open TestHelpers
|
||||
open GenericDist_Fixtures
|
||||
|
||||
// integral from low to high of 1 / (high - low) log(normal(mean, stdev)(x) / (1 / (high - low))) dx
|
||||
let klNormalUniform = (mean, stdev, low, high): float =>
|
||||
-.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. stdev ** 2.0)) +.
|
||||
1.0 /.
|
||||
stdev ** 2.0 *.
|
||||
(mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0)
|
||||
|
||||
describe("klDivergence: continuous -> continuous -> float", () => {
|
||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||
exception KlFailed
|
||||
|
||||
let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
|
||||
test("of two uniforms is equal to the analytic expression", () => {
|
||||
|
@ -59,6 +65,20 @@ describe("klDivergence: continuous -> continuous -> float", () => {
|
|||
}
|
||||
}
|
||||
})
|
||||
|
||||
test("of a normal and a uniform is equal to the formula", () => {
|
||||
let prediction = normalDist10
|
||||
let answer = uniformDist
|
||||
let kl = klDivergence(prediction, answer)
|
||||
let analyticalKl = klNormalUniform(10.0, 2.0, 9.0, 10.0)
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1)
|
||||
| Error(err) => {
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe("klDivergence: discrete -> discrete -> float", () => {
|
||||
|
@ -96,6 +116,64 @@ describe("klDivergence: discrete -> discrete -> float", () => {
|
|||
})
|
||||
})
|
||||
|
||||
describe("klDivergence: mixed -> mixed -> float", () => {
|
||||
let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
|
||||
let mixture' = a => DistributionTypes.DistributionOperation.Mixture(a)
|
||||
let mixture = a => {
|
||||
let dist' = a->mixture'->run
|
||||
switch dist' {
|
||||
| Dist(dist) => dist
|
||||
| _ => raise(MixtureFailed)
|
||||
}
|
||||
}
|
||||
let a = [(point1, 1.0), (uniformDist, 1.0)]->mixture
|
||||
let b = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture
|
||||
let c = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture
|
||||
let d =
|
||||
[(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]->mixture
|
||||
|
||||
test("finite klDivergence produces correct answer", () => {
|
||||
let prediction = b
|
||||
let answer = a
|
||||
let kl = klDivergence(prediction, answer)
|
||||
// high = 10; low = 9; mean = 10; stdev = 2
|
||||
let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0
|
||||
let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0)
|
||||
switch kl {
|
||||
| Ok(kl') =>
|
||||
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
|
||||
| Error(err) =>
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
}
|
||||
})
|
||||
test("returns infinity when infinite", () => {
|
||||
let prediction = a
|
||||
let answer = b
|
||||
let kl = klDivergence(prediction, answer)
|
||||
switch kl {
|
||||
| Ok(kl') => kl'->expect->toEqual(infinity)
|
||||
| Error(err) =>
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
}
|
||||
})
|
||||
test("finite klDivergence produces correct answer", () => {
|
||||
let prediction = d
|
||||
let answer = c
|
||||
let kl = klDivergence(prediction, answer)
|
||||
let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array
|
||||
let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0)
|
||||
switch kl {
|
||||
| Ok(kl') =>
|
||||
kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
|
||||
| Error(err) =>
|
||||
Js.Console.log(DistributionTypes.Error.toString(err))
|
||||
raise(KlFailed)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe("combineAlongSupportOfSecondArgument0", () => {
|
||||
// This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later.
|
||||
test("test on two uniforms", _ => {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "@quri/squiggle-lang",
|
||||
"version": "0.2.8",
|
||||
"version": "0.2.9",
|
||||
"homepage": "https://squiggle-language.com",
|
||||
"license": "MIT",
|
||||
"scripts": {
|
||||
|
|
|
@ -302,10 +302,9 @@ module T = Dist({
|
|||
}
|
||||
|
||||
let klDivergence = (prediction: t, answer: t) => {
|
||||
Error(Operation.NotYetImplemented)
|
||||
// combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
|
||||
// integralEndY,
|
||||
// )
|
||||
let klDiscretePart = Discrete.T.klDivergence(prediction.discrete, answer.discrete)
|
||||
let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
|
||||
E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -200,6 +200,7 @@ module T = Dist({
|
|||
switch (t1, t2) {
|
||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
||||
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
||||
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
|
||||
| _ => Error(NotYetImplemented)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -1,13 +1,6 @@
|
|||
module ExpressionValue = ReducerInterface_ExpressionValue
|
||||
type expressionValue = ReducerInterface_ExpressionValue.expressionValue
|
||||
|
||||
let defaultEnv: DistributionOperation.env = {
|
||||
sampleCount: MagicNumbers.Environment.defaultSampleCount,
|
||||
xyPointLength: MagicNumbers.Environment.defaultXYPointLength,
|
||||
}
|
||||
|
||||
let runGenericOperation = DistributionOperation.run(~env=defaultEnv)
|
||||
|
||||
module Helpers = {
|
||||
let arithmeticMap = r =>
|
||||
switch r {
|
||||
|
@ -39,37 +32,44 @@ module Helpers = {
|
|||
let toFloatFn = (
|
||||
fnCall: DistributionTypes.DistributionOperation.toFloat,
|
||||
dist: DistributionTypes.genericDist,
|
||||
~env: DistributionOperation.env,
|
||||
) => {
|
||||
FromDist(DistributionTypes.DistributionOperation.ToFloat(fnCall), dist)
|
||||
->runGenericOperation
|
||||
->DistributionOperation.run(~env)
|
||||
->Some
|
||||
}
|
||||
|
||||
let toStringFn = (
|
||||
fnCall: DistributionTypes.DistributionOperation.toString,
|
||||
dist: DistributionTypes.genericDist,
|
||||
~env: DistributionOperation.env,
|
||||
) => {
|
||||
FromDist(DistributionTypes.DistributionOperation.ToString(fnCall), dist)
|
||||
->runGenericOperation
|
||||
->DistributionOperation.run(~env)
|
||||
->Some
|
||||
}
|
||||
|
||||
let toBoolFn = (
|
||||
fnCall: DistributionTypes.DistributionOperation.toBool,
|
||||
dist: DistributionTypes.genericDist,
|
||||
~env: DistributionOperation.env,
|
||||
) => {
|
||||
FromDist(DistributionTypes.DistributionOperation.ToBool(fnCall), dist)
|
||||
->runGenericOperation
|
||||
->DistributionOperation.run(~env)
|
||||
->Some
|
||||
}
|
||||
|
||||
let toDistFn = (fnCall: DistributionTypes.DistributionOperation.toDist, dist) => {
|
||||
let toDistFn = (
|
||||
fnCall: DistributionTypes.DistributionOperation.toDist,
|
||||
dist,
|
||||
~env: DistributionOperation.env,
|
||||
) => {
|
||||
FromDist(DistributionTypes.DistributionOperation.ToDist(fnCall), dist)
|
||||
->runGenericOperation
|
||||
->DistributionOperation.run(~env)
|
||||
->Some
|
||||
}
|
||||
|
||||
let twoDiststoDistFn = (direction, arithmetic, dist1, dist2) => {
|
||||
let twoDiststoDistFn = (direction, arithmetic, dist1, dist2, ~env: DistributionOperation.env) => {
|
||||
FromDist(
|
||||
DistributionTypes.DistributionOperation.ToDistCombination(
|
||||
direction,
|
||||
|
@ -77,7 +77,7 @@ module Helpers = {
|
|||
#Dist(dist2),
|
||||
),
|
||||
dist1,
|
||||
)->runGenericOperation
|
||||
)->DistributionOperation.run(~env)
|
||||
}
|
||||
|
||||
let parseNumber = (args: expressionValue): Belt.Result.t<float, string> =>
|
||||
|
@ -104,33 +104,38 @@ module Helpers = {
|
|||
let mixtureWithGivenWeights = (
|
||||
distributions: array<DistributionTypes.genericDist>,
|
||||
weights: array<float>,
|
||||
~env: DistributionOperation.env,
|
||||
): DistributionOperation.outputType =>
|
||||
E.A.length(distributions) == E.A.length(weights)
|
||||
? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation
|
||||
? Mixture(Belt.Array.zip(distributions, weights))->DistributionOperation.run(~env)
|
||||
: GenDistError(
|
||||
ArgumentError("Error, mixture call has different number of distributions and weights"),
|
||||
)
|
||||
|
||||
let mixtureWithDefaultWeights = (
|
||||
distributions: array<DistributionTypes.genericDist>,
|
||||
~env: DistributionOperation.env,
|
||||
): DistributionOperation.outputType => {
|
||||
let length = E.A.length(distributions)
|
||||
let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length))
|
||||
mixtureWithGivenWeights(distributions, weights)
|
||||
mixtureWithGivenWeights(distributions, weights, ~env)
|
||||
}
|
||||
|
||||
let mixture = (args: array<expressionValue>): DistributionOperation.outputType => {
|
||||
let mixture = (
|
||||
args: array<expressionValue>,
|
||||
~env: DistributionOperation.env,
|
||||
): DistributionOperation.outputType => {
|
||||
let error = (err: string): DistributionOperation.outputType =>
|
||||
err->DistributionTypes.ArgumentError->GenDistError
|
||||
switch args {
|
||||
| [EvArray(distributions)] =>
|
||||
switch parseDistributionArray(distributions) {
|
||||
| Ok(distrs) => mixtureWithDefaultWeights(distrs)
|
||||
| Ok(distrs) => mixtureWithDefaultWeights(distrs, ~env)
|
||||
| Error(err) => error(err)
|
||||
}
|
||||
| [EvArray(distributions), EvArray(weights)] =>
|
||||
switch (parseDistributionArray(distributions), parseNumberArray(weights)) {
|
||||
| (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts)
|
||||
| (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts, ~env)
|
||||
| (Error(err), Ok(_)) => error(err)
|
||||
| (Ok(_), Error(err)) => error(err)
|
||||
| (Error(err1), Error(err2)) => error(`${err1}|${err2}`)
|
||||
|
@ -143,14 +148,14 @@ module Helpers = {
|
|||
Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1),
|
||||
)
|
||||
switch E.R.merge(distributions, weights) {
|
||||
| Ok(d, w) => mixtureWithGivenWeights(d, w)
|
||||
| Ok(d, w) => mixtureWithGivenWeights(d, w, ~env)
|
||||
| Error(err) => error(err)
|
||||
}
|
||||
}
|
||||
| Some(EvNumber(_))
|
||||
| Some(EvDistribution(_)) =>
|
||||
switch parseDistributionArray(args) {
|
||||
| Ok(distributions) => mixtureWithDefaultWeights(distributions)
|
||||
| Ok(distributions) => mixtureWithDefaultWeights(distributions, ~env)
|
||||
| Error(err) => error(err)
|
||||
}
|
||||
| _ => error("Last argument of mx must be array or distribution")
|
||||
|
@ -193,9 +198,10 @@ module SymbolicConstructors = {
|
|||
}
|
||||
}
|
||||
|
||||
let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment): option<
|
||||
DistributionOperation.outputType,
|
||||
> => {
|
||||
let dispatchToGenericOutput = (
|
||||
call: ExpressionValue.functionCall,
|
||||
env: DistributionOperation.env,
|
||||
): option<DistributionOperation.outputType> => {
|
||||
let (fnName, args) = call
|
||||
switch (fnName, args) {
|
||||
| ("exponential" as fnName, [EvNumber(f)]) =>
|
||||
|
@ -215,16 +221,16 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
|
|||
SymbolicConstructors.threeFloat(fnName)
|
||||
->E.R.bind(r => r(f1, f2, f3))
|
||||
->SymbolicConstructors.symbolicResultToOutput
|
||||
| ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist)
|
||||
| ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist, ~env)
|
||||
| ("sampleN", [EvDistribution(dist), EvNumber(n)]) => Some(
|
||||
FloatArray(GenericDist.sampleN(dist, Belt.Int.fromFloat(n))),
|
||||
)
|
||||
| ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist)
|
||||
| ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist)
|
||||
| ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist)
|
||||
| ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist)
|
||||
| ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist, ~env)
|
||||
| ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist, ~env)
|
||||
| ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist, ~env)
|
||||
| ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist, ~env)
|
||||
| ("toSparkline", [EvDistribution(dist), EvNumber(n)]) =>
|
||||
Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist)
|
||||
Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist, ~env)
|
||||
| ("exp", [EvDistribution(a)]) =>
|
||||
// https://mathjs.org/docs/reference/functions/exp.html
|
||||
Helpers.twoDiststoDistFn(
|
||||
|
@ -232,60 +238,74 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
|
|||
"pow",
|
||||
GenericDist.fromFloat(MagicNumbers.Math.e),
|
||||
a,
|
||||
~env,
|
||||
)->Some
|
||||
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist)
|
||||
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env)
|
||||
| ("klDivergence", [EvDistribution(a), EvDistribution(b)]) =>
|
||||
Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a)))
|
||||
| ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist)
|
||||
| ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist)
|
||||
Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env))
|
||||
| ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env)
|
||||
| ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env)
|
||||
| ("scaleLog", [EvDistribution(dist)]) =>
|
||||
Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist)
|
||||
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist)
|
||||
Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist, ~env)
|
||||
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist, ~env)
|
||||
| ("scaleLog", [EvDistribution(dist), EvNumber(float)]) =>
|
||||
Helpers.toDistFn(Scale(#Logarithm, float), dist)
|
||||
Helpers.toDistFn(Scale(#Logarithm, float), dist, ~env)
|
||||
| ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) =>
|
||||
Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist)
|
||||
Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist, ~env)
|
||||
| ("scalePow", [EvDistribution(dist), EvNumber(float)]) =>
|
||||
Helpers.toDistFn(Scale(#Power, float), dist)
|
||||
Helpers.toDistFn(Scale(#Power, float), dist, ~env)
|
||||
| ("scaleExp", [EvDistribution(dist)]) =>
|
||||
Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist)
|
||||
| ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist)
|
||||
| ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist)
|
||||
| ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist)
|
||||
Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist, ~env)
|
||||
| ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist, ~env)
|
||||
| ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist, ~env)
|
||||
| ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist, ~env)
|
||||
| ("toSampleSet", [EvDistribution(dist), EvNumber(float)]) =>
|
||||
Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist)
|
||||
Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist, ~env)
|
||||
| ("toSampleSet", [EvDistribution(dist)]) =>
|
||||
Helpers.toDistFn(ToSampleSet(MagicNumbers.Environment.defaultSampleCount), dist)
|
||||
Helpers.toDistFn(ToSampleSet(env.sampleCount), dist, ~env)
|
||||
| ("fromSamples", [EvArray(inputArray)]) => {
|
||||
let _wrapInputErrors = x => SampleSetDist.NonNumericInput(x)
|
||||
let parsedArray = Helpers.parseNumberArray(inputArray)->E.R2.errMap(_wrapInputErrors)
|
||||
switch parsedArray {
|
||||
| Ok(array) => runGenericOperation(FromSamples(array))
|
||||
| Ok(array) => DistributionOperation.run(FromSamples(array), ~env)
|
||||
| Error(e) => GenDistError(SampleSetError(e))
|
||||
}->Some
|
||||
}
|
||||
| ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist)
|
||||
| ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist, ~env)
|
||||
| ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) =>
|
||||
Helpers.toDistFn(Truncate(Some(float), None), dist)
|
||||
Helpers.toDistFn(Truncate(Some(float), None), dist, ~env)
|
||||
| ("truncateRight", [EvDistribution(dist), EvNumber(float)]) =>
|
||||
Helpers.toDistFn(Truncate(None, Some(float)), dist)
|
||||
Helpers.toDistFn(Truncate(None, Some(float)), dist, ~env)
|
||||
| ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) =>
|
||||
Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist)
|
||||
| ("mx" | "mixture", args) => Helpers.mixture(args)->Some
|
||||
Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist, ~env)
|
||||
| ("mx" | "mixture", args) => Helpers.mixture(args, ~env)->Some
|
||||
| ("log", [EvDistribution(a)]) =>
|
||||
Helpers.twoDiststoDistFn(
|
||||
Algebraic(AsDefault),
|
||||
"log",
|
||||
a,
|
||||
GenericDist.fromFloat(MagicNumbers.Math.e),
|
||||
~env,
|
||||
)->Some
|
||||
| ("log10", [EvDistribution(a)]) =>
|
||||
Helpers.twoDiststoDistFn(Algebraic(AsDefault), "log", a, GenericDist.fromFloat(10.0))->Some
|
||||
Helpers.twoDiststoDistFn(
|
||||
Algebraic(AsDefault),
|
||||
"log",
|
||||
a,
|
||||
GenericDist.fromFloat(10.0),
|
||||
~env,
|
||||
)->Some
|
||||
| ("unaryMinus", [EvDistribution(a)]) =>
|
||||
Helpers.twoDiststoDistFn(Algebraic(AsDefault), "multiply", a, GenericDist.fromFloat(-1.0))->Some
|
||||
Helpers.twoDiststoDistFn(
|
||||
Algebraic(AsDefault),
|
||||
"multiply",
|
||||
a,
|
||||
GenericDist.fromFloat(-1.0),
|
||||
~env,
|
||||
)->Some
|
||||
| (("add" | "multiply" | "subtract" | "divide" | "pow" | "log") as arithmetic, [_, _] as args) =>
|
||||
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
|
||||
Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd)
|
||||
Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd, ~env)
|
||||
)
|
||||
| (
|
||||
("dotAdd"
|
||||
|
@ -296,7 +316,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
|
|||
[_, _] as args,
|
||||
) =>
|
||||
Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) =>
|
||||
Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd)
|
||||
Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd, ~env)
|
||||
)
|
||||
| ("dotExp", [EvDistribution(a)]) =>
|
||||
Helpers.twoDiststoDistFn(
|
||||
|
@ -304,6 +324,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
|
|||
"dotPow",
|
||||
GenericDist.fromFloat(MagicNumbers.Math.e),
|
||||
a,
|
||||
~env,
|
||||
)->Some
|
||||
| _ => None
|
||||
}
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
let defaultEnv: DistributionOperation.env
|
||||
let dispatch: (
|
||||
ReducerInterface_ExpressionValue.functionCall,
|
||||
ReducerInterface_ExpressionValue.environment,
|
||||
|
|
|
@ -77,7 +77,7 @@ let distributionErrorToString = DistributionTypes.Error.toString
|
|||
type lambdaValue = ReducerInterface_ExpressionValue.lambdaValue
|
||||
|
||||
@genType
|
||||
let defaultSamplingEnv = ReducerInterface_GenericDistribution.defaultEnv
|
||||
let defaultSamplingEnv = DistributionOperation.defaultEnv
|
||||
|
||||
@genType
|
||||
type environment = ReducerInterface_ExpressionValue.environment
|
||||
|
|
|
@ -453,6 +453,44 @@ module PointwiseCombination = {
|
|||
T.filterOkYs(newXs, newYs)->Ok
|
||||
}
|
||||
|
||||
/* *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
|
||||
If another traveler comes through with a similar idea, we hope this implementation will help them.
|
||||
By "enrich" we mean to increase granularity.
|
||||
*/
|
||||
let enrichXyShape = (t: T.t): T.t => {
|
||||
let defaultEnrichmentFactor = 10
|
||||
let length = E.A.length(t.xs)
|
||||
let points =
|
||||
length < MagicNumbers.Environment.defaultXYPointLength
|
||||
? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
|
||||
: defaultEnrichmentFactor
|
||||
|
||||
let getInBetween = (x1: float, x2: float): array<float> => {
|
||||
if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {
|
||||
[x1]
|
||||
} else {
|
||||
let newPointsArray = Belt.Array.makeBy(points - 1, i => i)
|
||||
// don't repeat the x2 point, it will be gotten in the next iteration.
|
||||
let result = Js.Array.mapi((pos, i) =>
|
||||
if i == 0 {
|
||||
x1
|
||||
} else {
|
||||
let points' = Belt.Float.fromInt(points)
|
||||
let pos' = Belt.Float.fromInt(pos)
|
||||
x1 *. (points' -. pos') /. points' +. x2 *. pos' /. points'
|
||||
}
|
||||
, newPointsArray)
|
||||
result
|
||||
}
|
||||
}
|
||||
let newXsUnflattened = Js.Array.mapi(
|
||||
(x, i) => i < length - 2 ? getInBetween(x, t.xs[i + 1]) : [x],
|
||||
t.xs,
|
||||
)
|
||||
let newXs = Belt.Array.concatMany(newXsUnflattened)
|
||||
let newYs = E.A.fmap(x => XtoY.linear(x, t), newXs)
|
||||
{xs: newXs, ys: newYs}
|
||||
}
|
||||
// This function is used for klDivergence
|
||||
let combineAlongSupportOfSecondArgument: (
|
||||
(float, float) => result<float, Operation.Error.t>,
|
||||
|
|
16
yarn.lock
16
yarn.lock
|
@ -3726,10 +3726,10 @@
|
|||
"@testing-library/dom" "^8.5.0"
|
||||
"@types/react-dom" "^18.0.0"
|
||||
|
||||
"@testing-library/user-event@^14.1.1":
|
||||
version "14.1.1"
|
||||
resolved "https://registry.yarnpkg.com/@testing-library/user-event/-/user-event-14.1.1.tgz#e1ff6118896e4b22af31e5ea2f9da956adde23d8"
|
||||
integrity sha512-XrjH/iEUqNl9lF2HX9YhPNV7Amntkcnpw0Bo1KkRzowNDcgSN9i0nm4Q8Oi5wupgdfPaJNMAWa61A+voD6Kmwg==
|
||||
"@testing-library/user-event@^14.2.0":
|
||||
version "14.2.0"
|
||||
resolved "https://registry.yarnpkg.com/@testing-library/user-event/-/user-event-14.2.0.tgz#8293560f8f80a00383d6c755ec3e0b918acb1683"
|
||||
integrity sha512-+hIlG4nJS6ivZrKnOP7OGsDu9Fxmryj9vCl8x0ZINtTJcCHs2zLsYif5GzuRiBF2ck5GZG2aQr7Msg+EHlnYVQ==
|
||||
|
||||
"@tootallnate/once@1":
|
||||
version "1.1.2"
|
||||
|
@ -4038,10 +4038,10 @@
|
|||
"@types/node" "*"
|
||||
form-data "^3.0.0"
|
||||
|
||||
"@types/node@*", "@types/node@^17.0.31", "@types/node@^17.0.5":
|
||||
version "17.0.31"
|
||||
resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.31.tgz#a5bb84ecfa27eec5e1c802c6bbf8139bdb163a5d"
|
||||
integrity sha512-AR0x5HbXGqkEx9CadRH3EBYx/VkiUgZIhP4wvPn/+5KIsgpNoyFaRlVe0Zlx9gRtg8fA06a9tskE2MSN7TcG4Q==
|
||||
"@types/node@*", "@types/node@^17.0.32", "@types/node@^17.0.5":
|
||||
version "17.0.32"
|
||||
resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.32.tgz#51d59d7a90ef2d0ae961791e0900cad2393a0149"
|
||||
integrity sha512-eAIcfAvhf/BkHcf4pkLJ7ECpBAhh9kcxRBpip9cTiO+hf+aJrsxYxBeS6OXvOd9WqNAJmavXVpZvY1rBjNsXmw==
|
||||
|
||||
"@types/node@^14.0.10":
|
||||
version "14.18.16"
|
||||
|
|
Loading…
Reference in New Issue
Block a user