diff --git a/packages/components/package.json b/packages/components/package.json index b5f1ab7c..e8513982 100644 --- a/packages/components/package.json +++ b/packages/components/package.json @@ -28,10 +28,10 @@ "@storybook/react": "^6.4.22", "@testing-library/jest-dom": "^5.16.4", "@testing-library/react": "^13.2.0", - "@testing-library/user-event": "^14.1.1", + "@testing-library/user-event": "^14.2.0", "@types/jest": "^27.5.0", "@types/lodash": "^4.14.182", - "@types/node": "^17.0.31", + "@types/node": "^17.0.32", "@types/react": "^18.0.9", "@types/react-dom": "^18.0.2", "@types/styled-components": "^5.1.24", diff --git a/packages/components/src/components/FunctionChart.tsx b/packages/components/src/components/FunctionChart.tsx index 2ee71721..2fd4587d 100644 --- a/packages/components/src/components/FunctionChart.tsx +++ b/packages/components/src/components/FunctionChart.tsx @@ -76,28 +76,32 @@ export const FunctionChart: React.FC = ({ chartSettings.count ); type point = { x: number; value: result }; - let valueData: point[] = data1.map((x) => { - let result = runForeign(fn, [x], environment); - if (result.tag === "Ok") { - if (result.value.tag == "distribution") { - return { x, value: { tag: "Ok", value: result.value.value } }; - } else { - return { - x, - value: { - tag: "Error", - value: - "Cannot currently render functions that don't return distributions", - }, - }; - } - } else { - return { - x, - value: { tag: "Error", value: errorValueToString(result.value) }, - }; - } - }); + let valueData: point[] = React.useMemo( + () => + data1.map((x) => { + let result = runForeign(fn, [x], environment); + if (result.tag === "Ok") { + if (result.value.tag == "distribution") { + return { x, value: { tag: "Ok", value: result.value.value } }; + } else { + return { + x, + value: { + tag: "Error", + value: + "Cannot currently render functions that don't return distributions", + }, + }; + } + } else { + return { + x, + value: { tag: "Error", value: errorValueToString(result.value) }, + }; + } + }), + [environment, fn] + ); let initialPartition: [ { x: number; value: Distribution }[], @@ -141,10 +145,10 @@ export const FunctionChart: React.FC = ({ /> {showChart} {_.entries(groupedErrors).map(([errorName, errorPoints]) => ( - + Values:{" "} {errorPoints - .map((r) => ) + .map((r, i) => ) .reduce((a, b) => ( <> {a}, {b} diff --git a/packages/components/src/components/SquiggleChart.tsx b/packages/components/src/components/SquiggleChart.tsx index 241cd772..a54ac64d 100644 --- a/packages/components/src/components/SquiggleChart.tsx +++ b/packages/components/src/components/SquiggleChart.tsx @@ -148,8 +148,9 @@ const SquiggleItem: React.FC = ({ case "array": return ( - {expression.value.map((r) => ( + {expression.value.map((r, i) => ( = ({ return ( {Object.entries(expression.value).map(([key, r]) => ( - <> +
{key} = ({ chartSettings={chartSettings} environment={environment} /> - +
))}
); case "arraystring": return ( - {expression.value.map((r) => `"${r}"`)} + {expression.value.map((r) => `"${r}"`).join(", ")} ); case "lambda": diff --git a/packages/squiggle-lang/__tests__/Distributions/GenericDist_Fixtures.res b/packages/squiggle-lang/__tests__/Distributions/GenericDist_Fixtures.res index 8e315599..d184b61b 100644 --- a/packages/squiggle-lang/__tests__/Distributions/GenericDist_Fixtures.res +++ b/packages/squiggle-lang/__tests__/Distributions/GenericDist_Fixtures.res @@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic( ) let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0})) let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0})) +let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0})) let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1)) exception KlFailed diff --git a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res index 96e95899..fc528e08 100644 --- a/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/KlDivergence_test.res @@ -3,9 +3,15 @@ open Expect open TestHelpers open GenericDist_Fixtures +// integral from low to high of 1 / (high - low) log(normal(mean, stdev)(x) / (1 / (high - low))) dx +let klNormalUniform = (mean, stdev, low, high): float => + -.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. stdev ** 2.0)) +. + 1.0 /. + stdev ** 2.0 *. + (mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0) + describe("klDivergence: continuous -> continuous -> float", () => { let klDivergence = DistributionOperation.Constructors.klDivergence(~env) - exception KlFailed let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => { test("of two uniforms is equal to the analytic expression", () => { @@ -59,6 +65,20 @@ describe("klDivergence: continuous -> continuous -> float", () => { } } }) + + test("of a normal and a uniform is equal to the formula", () => { + let prediction = normalDist10 + let answer = uniformDist + let kl = klDivergence(prediction, answer) + let analyticalKl = klNormalUniform(10.0, 2.0, 9.0, 10.0) + switch kl { + | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1) + | Error(err) => { + Js.Console.log(DistributionTypes.Error.toString(err)) + raise(KlFailed) + } + } + }) }) describe("klDivergence: discrete -> discrete -> float", () => { @@ -96,6 +116,64 @@ describe("klDivergence: discrete -> discrete -> float", () => { }) }) +describe("klDivergence: mixed -> mixed -> float", () => { + let klDivergence = DistributionOperation.Constructors.klDivergence(~env) + let mixture' = a => DistributionTypes.DistributionOperation.Mixture(a) + let mixture = a => { + let dist' = a->mixture'->run + switch dist' { + | Dist(dist) => dist + | _ => raise(MixtureFailed) + } + } + let a = [(point1, 1.0), (uniformDist, 1.0)]->mixture + let b = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture + let c = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture + let d = + [(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]->mixture + + test("finite klDivergence produces correct answer", () => { + let prediction = b + let answer = a + let kl = klDivergence(prediction, answer) + // high = 10; low = 9; mean = 10; stdev = 2 + let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0 + let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0) + switch kl { + | Ok(kl') => + kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1) + | Error(err) => + Js.Console.log(DistributionTypes.Error.toString(err)) + raise(KlFailed) + } + }) + test("returns infinity when infinite", () => { + let prediction = a + let answer = b + let kl = klDivergence(prediction, answer) + switch kl { + | Ok(kl') => kl'->expect->toEqual(infinity) + | Error(err) => + Js.Console.log(DistributionTypes.Error.toString(err)) + raise(KlFailed) + } + }) + test("finite klDivergence produces correct answer", () => { + let prediction = d + let answer = c + let kl = klDivergence(prediction, answer) + let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array + let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0) + switch kl { + | Ok(kl') => + kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1) + | Error(err) => + Js.Console.log(DistributionTypes.Error.toString(err)) + raise(KlFailed) + } + }) +}) + describe("combineAlongSupportOfSecondArgument0", () => { // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later. test("test on two uniforms", _ => { diff --git a/packages/squiggle-lang/package.json b/packages/squiggle-lang/package.json index ff786b91..d21f2eb6 100644 --- a/packages/squiggle-lang/package.json +++ b/packages/squiggle-lang/package.json @@ -1,6 +1,6 @@ { "name": "@quri/squiggle-lang", - "version": "0.2.8", + "version": "0.2.9", "homepage": "https://squiggle-language.com", "license": "MIT", "scripts": { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 05465728..7bbe2065 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -302,10 +302,9 @@ module T = Dist({ } let klDivergence = (prediction: t, answer: t) => { - Error(Operation.NotYetImplemented) - // combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap( - // integralEndY, - // ) + let klDiscretePart = Discrete.T.klDivergence(prediction.discrete, answer.discrete) + let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous) + E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t)) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 1879ebdd..db47d1e1 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -200,6 +200,7 @@ module T = Dist({ switch (t1, t2) { | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) + | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2) | _ => Error(NotYetImplemented) } }) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 37bffeef..448d7f1b 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -1,13 +1,6 @@ module ExpressionValue = ReducerInterface_ExpressionValue type expressionValue = ReducerInterface_ExpressionValue.expressionValue -let defaultEnv: DistributionOperation.env = { - sampleCount: MagicNumbers.Environment.defaultSampleCount, - xyPointLength: MagicNumbers.Environment.defaultXYPointLength, -} - -let runGenericOperation = DistributionOperation.run(~env=defaultEnv) - module Helpers = { let arithmeticMap = r => switch r { @@ -39,37 +32,44 @@ module Helpers = { let toFloatFn = ( fnCall: DistributionTypes.DistributionOperation.toFloat, dist: DistributionTypes.genericDist, + ~env: DistributionOperation.env, ) => { FromDist(DistributionTypes.DistributionOperation.ToFloat(fnCall), dist) - ->runGenericOperation + ->DistributionOperation.run(~env) ->Some } let toStringFn = ( fnCall: DistributionTypes.DistributionOperation.toString, dist: DistributionTypes.genericDist, + ~env: DistributionOperation.env, ) => { FromDist(DistributionTypes.DistributionOperation.ToString(fnCall), dist) - ->runGenericOperation + ->DistributionOperation.run(~env) ->Some } let toBoolFn = ( fnCall: DistributionTypes.DistributionOperation.toBool, dist: DistributionTypes.genericDist, + ~env: DistributionOperation.env, ) => { FromDist(DistributionTypes.DistributionOperation.ToBool(fnCall), dist) - ->runGenericOperation + ->DistributionOperation.run(~env) ->Some } - let toDistFn = (fnCall: DistributionTypes.DistributionOperation.toDist, dist) => { + let toDistFn = ( + fnCall: DistributionTypes.DistributionOperation.toDist, + dist, + ~env: DistributionOperation.env, + ) => { FromDist(DistributionTypes.DistributionOperation.ToDist(fnCall), dist) - ->runGenericOperation + ->DistributionOperation.run(~env) ->Some } - let twoDiststoDistFn = (direction, arithmetic, dist1, dist2) => { + let twoDiststoDistFn = (direction, arithmetic, dist1, dist2, ~env: DistributionOperation.env) => { FromDist( DistributionTypes.DistributionOperation.ToDistCombination( direction, @@ -77,7 +77,7 @@ module Helpers = { #Dist(dist2), ), dist1, - )->runGenericOperation + )->DistributionOperation.run(~env) } let parseNumber = (args: expressionValue): Belt.Result.t => @@ -104,33 +104,38 @@ module Helpers = { let mixtureWithGivenWeights = ( distributions: array, weights: array, + ~env: DistributionOperation.env, ): DistributionOperation.outputType => E.A.length(distributions) == E.A.length(weights) - ? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation + ? Mixture(Belt.Array.zip(distributions, weights))->DistributionOperation.run(~env) : GenDistError( ArgumentError("Error, mixture call has different number of distributions and weights"), ) let mixtureWithDefaultWeights = ( distributions: array, + ~env: DistributionOperation.env, ): DistributionOperation.outputType => { let length = E.A.length(distributions) let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length)) - mixtureWithGivenWeights(distributions, weights) + mixtureWithGivenWeights(distributions, weights, ~env) } - let mixture = (args: array): DistributionOperation.outputType => { + let mixture = ( + args: array, + ~env: DistributionOperation.env, + ): DistributionOperation.outputType => { let error = (err: string): DistributionOperation.outputType => err->DistributionTypes.ArgumentError->GenDistError switch args { | [EvArray(distributions)] => switch parseDistributionArray(distributions) { - | Ok(distrs) => mixtureWithDefaultWeights(distrs) + | Ok(distrs) => mixtureWithDefaultWeights(distrs, ~env) | Error(err) => error(err) } | [EvArray(distributions), EvArray(weights)] => switch (parseDistributionArray(distributions), parseNumberArray(weights)) { - | (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts) + | (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts, ~env) | (Error(err), Ok(_)) => error(err) | (Ok(_), Error(err)) => error(err) | (Error(err1), Error(err2)) => error(`${err1}|${err2}`) @@ -143,14 +148,14 @@ module Helpers = { Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1), ) switch E.R.merge(distributions, weights) { - | Ok(d, w) => mixtureWithGivenWeights(d, w) + | Ok(d, w) => mixtureWithGivenWeights(d, w, ~env) | Error(err) => error(err) } } | Some(EvNumber(_)) | Some(EvDistribution(_)) => switch parseDistributionArray(args) { - | Ok(distributions) => mixtureWithDefaultWeights(distributions) + | Ok(distributions) => mixtureWithDefaultWeights(distributions, ~env) | Error(err) => error(err) } | _ => error("Last argument of mx must be array or distribution") @@ -193,9 +198,10 @@ module SymbolicConstructors = { } } -let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment): option< - DistributionOperation.outputType, -> => { +let dispatchToGenericOutput = ( + call: ExpressionValue.functionCall, + env: DistributionOperation.env, +): option => { let (fnName, args) = call switch (fnName, args) { | ("exponential" as fnName, [EvNumber(f)]) => @@ -215,16 +221,16 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) SymbolicConstructors.threeFloat(fnName) ->E.R.bind(r => r(f1, f2, f3)) ->SymbolicConstructors.symbolicResultToOutput - | ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist) + | ("sample", [EvDistribution(dist)]) => Helpers.toFloatFn(#Sample, dist, ~env) | ("sampleN", [EvDistribution(dist), EvNumber(n)]) => Some( FloatArray(GenericDist.sampleN(dist, Belt.Int.fromFloat(n))), ) - | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist) - | ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist) - | ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist) - | ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist) + | ("mean", [EvDistribution(dist)]) => Helpers.toFloatFn(#Mean, dist, ~env) + | ("integralSum", [EvDistribution(dist)]) => Helpers.toFloatFn(#IntegralSum, dist, ~env) + | ("toString", [EvDistribution(dist)]) => Helpers.toStringFn(ToString, dist, ~env) + | ("toSparkline", [EvDistribution(dist)]) => Helpers.toStringFn(ToSparkline(20), dist, ~env) | ("toSparkline", [EvDistribution(dist), EvNumber(n)]) => - Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist) + Helpers.toStringFn(ToSparkline(Belt.Float.toInt(n)), dist, ~env) | ("exp", [EvDistribution(a)]) => // https://mathjs.org/docs/reference/functions/exp.html Helpers.twoDiststoDistFn( @@ -232,60 +238,74 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) "pow", GenericDist.fromFloat(MagicNumbers.Math.e), a, + ~env, )->Some - | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist) + | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env) | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) => - Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a))) - | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist) - | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist) + Some(DistributionOperation.run(FromDist(ToScore(KLDivergence(b)), a), ~env)) + | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env) + | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env) | ("scaleLog", [EvDistribution(dist)]) => - Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist) - | ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist) + Helpers.toDistFn(Scale(#Logarithm, MagicNumbers.Math.e), dist, ~env) + | ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist, ~env) | ("scaleLog", [EvDistribution(dist), EvNumber(float)]) => - Helpers.toDistFn(Scale(#Logarithm, float), dist) + Helpers.toDistFn(Scale(#Logarithm, float), dist, ~env) | ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) => - Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist) + Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist, ~env) | ("scalePow", [EvDistribution(dist), EvNumber(float)]) => - Helpers.toDistFn(Scale(#Power, float), dist) + Helpers.toDistFn(Scale(#Power, float), dist, ~env) | ("scaleExp", [EvDistribution(dist)]) => - Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist) - | ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist) - | ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist) - | ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist) + Helpers.toDistFn(Scale(#Power, MagicNumbers.Math.e), dist, ~env) + | ("cdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Cdf(float), dist, ~env) + | ("pdf", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Pdf(float), dist, ~env) + | ("inv", [EvDistribution(dist), EvNumber(float)]) => Helpers.toFloatFn(#Inv(float), dist, ~env) | ("toSampleSet", [EvDistribution(dist), EvNumber(float)]) => - Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist) + Helpers.toDistFn(ToSampleSet(Belt.Int.fromFloat(float)), dist, ~env) | ("toSampleSet", [EvDistribution(dist)]) => - Helpers.toDistFn(ToSampleSet(MagicNumbers.Environment.defaultSampleCount), dist) + Helpers.toDistFn(ToSampleSet(env.sampleCount), dist, ~env) | ("fromSamples", [EvArray(inputArray)]) => { let _wrapInputErrors = x => SampleSetDist.NonNumericInput(x) let parsedArray = Helpers.parseNumberArray(inputArray)->E.R2.errMap(_wrapInputErrors) switch parsedArray { - | Ok(array) => runGenericOperation(FromSamples(array)) + | Ok(array) => DistributionOperation.run(FromSamples(array), ~env) | Error(e) => GenDistError(SampleSetError(e)) }->Some } - | ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist) + | ("inspect", [EvDistribution(dist)]) => Helpers.toDistFn(Inspect, dist, ~env) | ("truncateLeft", [EvDistribution(dist), EvNumber(float)]) => - Helpers.toDistFn(Truncate(Some(float), None), dist) + Helpers.toDistFn(Truncate(Some(float), None), dist, ~env) | ("truncateRight", [EvDistribution(dist), EvNumber(float)]) => - Helpers.toDistFn(Truncate(None, Some(float)), dist) + Helpers.toDistFn(Truncate(None, Some(float)), dist, ~env) | ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) => - Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist) - | ("mx" | "mixture", args) => Helpers.mixture(args)->Some + Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist, ~env) + | ("mx" | "mixture", args) => Helpers.mixture(args, ~env)->Some | ("log", [EvDistribution(a)]) => Helpers.twoDiststoDistFn( Algebraic(AsDefault), "log", a, GenericDist.fromFloat(MagicNumbers.Math.e), + ~env, )->Some | ("log10", [EvDistribution(a)]) => - Helpers.twoDiststoDistFn(Algebraic(AsDefault), "log", a, GenericDist.fromFloat(10.0))->Some + Helpers.twoDiststoDistFn( + Algebraic(AsDefault), + "log", + a, + GenericDist.fromFloat(10.0), + ~env, + )->Some | ("unaryMinus", [EvDistribution(a)]) => - Helpers.twoDiststoDistFn(Algebraic(AsDefault), "multiply", a, GenericDist.fromFloat(-1.0))->Some + Helpers.twoDiststoDistFn( + Algebraic(AsDefault), + "multiply", + a, + GenericDist.fromFloat(-1.0), + ~env, + )->Some | (("add" | "multiply" | "subtract" | "divide" | "pow" | "log") as arithmetic, [_, _] as args) => Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => - Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd) + Helpers.twoDiststoDistFn(Algebraic(AsDefault), arithmetic, fst, snd, ~env) ) | ( ("dotAdd" @@ -296,7 +316,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) [_, _] as args, ) => Helpers.catchAndConvertTwoArgsToDists(args)->E.O2.fmap(((fst, snd)) => - Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd) + Helpers.twoDiststoDistFn(Pointwise, arithmetic, fst, snd, ~env) ) | ("dotExp", [EvDistribution(a)]) => Helpers.twoDiststoDistFn( @@ -304,6 +324,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) "dotPow", GenericDist.fromFloat(MagicNumbers.Math.e), a, + ~env, )->Some | _ => None } diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.resi b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.resi index 038f4479..7f26a610 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.resi +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.resi @@ -1,4 +1,3 @@ -let defaultEnv: DistributionOperation.env let dispatch: ( ReducerInterface_ExpressionValue.functionCall, ReducerInterface_ExpressionValue.environment, diff --git a/packages/squiggle-lang/src/rescript/TypescriptInterface.res b/packages/squiggle-lang/src/rescript/TypescriptInterface.res index 13763e72..93af9832 100644 --- a/packages/squiggle-lang/src/rescript/TypescriptInterface.res +++ b/packages/squiggle-lang/src/rescript/TypescriptInterface.res @@ -77,7 +77,7 @@ let distributionErrorToString = DistributionTypes.Error.toString type lambdaValue = ReducerInterface_ExpressionValue.lambdaValue @genType -let defaultSamplingEnv = ReducerInterface_GenericDistribution.defaultEnv +let defaultSamplingEnv = DistributionOperation.defaultEnv @genType type environment = ReducerInterface_ExpressionValue.environment diff --git a/packages/squiggle-lang/src/rescript/Utility/XYShape.res b/packages/squiggle-lang/src/rescript/Utility/XYShape.res index 60d0bbde..b4758dfd 100644 --- a/packages/squiggle-lang/src/rescript/Utility/XYShape.res +++ b/packages/squiggle-lang/src/rescript/Utility/XYShape.res @@ -453,6 +453,44 @@ module PointwiseCombination = { T.filterOkYs(newXs, newYs)->Ok } + /* *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work. + If another traveler comes through with a similar idea, we hope this implementation will help them. + By "enrich" we mean to increase granularity. + */ + let enrichXyShape = (t: T.t): T.t => { + let defaultEnrichmentFactor = 10 + let length = E.A.length(t.xs) + let points = + length < MagicNumbers.Environment.defaultXYPointLength + ? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length + : defaultEnrichmentFactor + + let getInBetween = (x1: float, x2: float): array => { + if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven { + [x1] + } else { + let newPointsArray = Belt.Array.makeBy(points - 1, i => i) + // don't repeat the x2 point, it will be gotten in the next iteration. + let result = Js.Array.mapi((pos, i) => + if i == 0 { + x1 + } else { + let points' = Belt.Float.fromInt(points) + let pos' = Belt.Float.fromInt(pos) + x1 *. (points' -. pos') /. points' +. x2 *. pos' /. points' + } + , newPointsArray) + result + } + } + let newXsUnflattened = Js.Array.mapi( + (x, i) => i < length - 2 ? getInBetween(x, t.xs[i + 1]) : [x], + t.xs, + ) + let newXs = Belt.Array.concatMany(newXsUnflattened) + let newYs = E.A.fmap(x => XtoY.linear(x, t), newXs) + {xs: newXs, ys: newYs} + } // This function is used for klDivergence let combineAlongSupportOfSecondArgument: ( (float, float) => result, diff --git a/yarn.lock b/yarn.lock index 351a1e62..81c8c21f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3726,10 +3726,10 @@ "@testing-library/dom" "^8.5.0" "@types/react-dom" "^18.0.0" -"@testing-library/user-event@^14.1.1": - version "14.1.1" - resolved "https://registry.yarnpkg.com/@testing-library/user-event/-/user-event-14.1.1.tgz#e1ff6118896e4b22af31e5ea2f9da956adde23d8" - integrity sha512-XrjH/iEUqNl9lF2HX9YhPNV7Amntkcnpw0Bo1KkRzowNDcgSN9i0nm4Q8Oi5wupgdfPaJNMAWa61A+voD6Kmwg== +"@testing-library/user-event@^14.2.0": + version "14.2.0" + resolved "https://registry.yarnpkg.com/@testing-library/user-event/-/user-event-14.2.0.tgz#8293560f8f80a00383d6c755ec3e0b918acb1683" + integrity sha512-+hIlG4nJS6ivZrKnOP7OGsDu9Fxmryj9vCl8x0ZINtTJcCHs2zLsYif5GzuRiBF2ck5GZG2aQr7Msg+EHlnYVQ== "@tootallnate/once@1": version "1.1.2" @@ -4038,10 +4038,10 @@ "@types/node" "*" form-data "^3.0.0" -"@types/node@*", "@types/node@^17.0.31", "@types/node@^17.0.5": - version "17.0.31" - resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.31.tgz#a5bb84ecfa27eec5e1c802c6bbf8139bdb163a5d" - integrity sha512-AR0x5HbXGqkEx9CadRH3EBYx/VkiUgZIhP4wvPn/+5KIsgpNoyFaRlVe0Zlx9gRtg8fA06a9tskE2MSN7TcG4Q== +"@types/node@*", "@types/node@^17.0.32", "@types/node@^17.0.5": + version "17.0.32" + resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.32.tgz#51d59d7a90ef2d0ae961791e0900cad2393a0149" + integrity sha512-eAIcfAvhf/BkHcf4pkLJ7ECpBAhh9kcxRBpip9cTiO+hf+aJrsxYxBeS6OXvOd9WqNAJmavXVpZvY1rBjNsXmw== "@types/node@^14.0.10": version "14.18.16"