From 117c08bfa95beb00bcc2e639cea7ee6ffbc7fe1f Mon Sep 17 00:00:00 2001 From: Sam Nolan Date: Wed, 13 Apr 2022 12:03:04 +1000 Subject: [PATCH] Fix unweighted average of distributions --- package.json | 3 +- packages/squiggle-lang/.prettierignore | 4 ++ .../ReducerInterface_Distribution_test.res | 12 +--- .../ReducerInterface_GenericDistribution.res | 66 +++++++++++++------ 4 files changed, 53 insertions(+), 32 deletions(-) create mode 100644 packages/squiggle-lang/.prettierignore diff --git a/package.json b/package.json index a0e8fd34..9db41bf5 100644 --- a/package.json +++ b/package.json @@ -3,7 +3,8 @@ "name": "squiggle", "scripts": { "nodeclean": "rm -r node_modules && rm -r packages/*/node_modules", - "format:all": "prettier --write . && cd packages/squiggle-lang && yarn format" + "format:all": "prettier --write . && cd packages/squiggle-lang && yarn format", + "lint:all": "prettier --check . && cd packages/squiggle-lang && yarn lint:rescript" }, "devDependencies": { "prettier": "^2.6.2" diff --git a/packages/squiggle-lang/.prettierignore b/packages/squiggle-lang/.prettierignore new file mode 100644 index 00000000..30674e4d --- /dev/null +++ b/packages/squiggle-lang/.prettierignore @@ -0,0 +1,4 @@ +dist +lib +*.bs.js +*.gen.tsx diff --git a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res index cc17593a..0a131d93 100644 --- a/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res +++ b/packages/squiggle-lang/__tests__/ReducerInterface/ReducerInterface_Distribution_test.res @@ -90,16 +90,8 @@ describe("eval on distribution functions", () => { }) describe("mixture", () => { - testEval( - ~skip=true, - "mx(normal(5,2), normal(10,1), normal(15, 1))", - "Ok(Point Set Distribution)", - ) - testEval( - ~skip=true, - "mixture(normal(5,2), normal(10,1), [.2,, .4])", - "Ok(Point Set Distribution)", - ) + testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)") + testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)") }) }) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 62835730..918c4b70 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -66,36 +66,62 @@ module Helpers = { dist1, )->runGenericOperation } - let parseNumber = (args: expressionValue) : Belt.Result.t => + let parseNumber = (args: expressionValue): Belt.Result.t => switch args { | EvNumber(x) => Ok(x) | _ => Error("Not a number") } - let parseNumberArray = (ags: array) : Belt.Result.t, string> => E.A.fmap(parseNumber, ags) |> E.A.R.firstErrorOrOpen + let parseNumberArray = (ags: array): Belt.Result.t, string> => + E.A.fmap(parseNumber, ags) |> E.A.R.firstErrorOrOpen - let parseDist = (args: expressionValue): Belt.Result.t => + let parseDist = (args: expressionValue): Belt.Result.t => switch args { | EvDistribution(x) => Ok(x) | EvNumber(x) => Ok(GenericDist.fromFloat(x)) | _ => Error("Not a distribution") } - let parseDistributionArray = (ags: array) : Belt.Result.t, string> => E.A.fmap(parseDist, ags) |> E.A.R.firstErrorOrOpen + let parseDistributionArray = (ags: array): Belt.Result.t< + array, + string, + > => E.A.fmap(parseDist, ags) |> E.A.R.firstErrorOrOpen - let mixture = (args : array): DistributionOperation.outputType => { - let givenWeights = E.A.last(args) - let calculatedWeights = - switch givenWeights { - | Some(EvArray(b)) => parseNumberArray(b) - | None => - Ok(Belt.Array.make(E.A.length(args), 1.0 /. Belt.Int.toFloat(E.A.length(args)))) - | _ => Error("Last argument of mx must be array") + let mixtureWithGivenWeights = ( + distributions: array, + weights: array, + ): DistributionOperation.outputType => + E.A.length(distributions) == E.A.length(weights) + ? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation + : GenDistError( + ArgumentError("Error, mixture call has different number of distributions and weights"), + ) + + let mixtureWithDefaultWeights = ( + distributions: array, + ): DistributionOperation.outputType => { + let length = E.A.length(distributions) + let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length)) + mixtureWithGivenWeights(distributions, weights) + } + + let mixture = (args: array): DistributionOperation.outputType => { + switch E.A.last(args) { + | Some(EvArray(b)) => { + let weights = parseNumberArray(b) + let distributions = parseDistributionArray( + Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1), + ) + switch E.R.merge(distributions, weights) { + | Ok(d, w) => mixtureWithGivenWeights(d, w) + | Error(err) => GenDistError(ArgumentError(err)) } - switch (parseDistributionArray(Belt.Array.slice(args, ~offset=0, ~len=Belt.Array.length(args)-1)), calculatedWeights) { - | (Ok(distArray), Ok(w)) => Mixture(Belt.Array.zip(distArray, w)) -> runGenericOperation - | (Error(err), _) => GenDistError(ArgumentError(err)) - | (_, Error(err)) => GenDistError(ArgumentError(err)) + } + | Some(EvDistribution(b)) => switch parseDistributionArray(args) { + | Ok(distributions) => mixtureWithDefaultWeights(distributions) + | Error(err) => GenDistError(ArgumentError(err)) + } + | _ => GenDistError(ArgumentError("Last argument of mx must be array or distribution")) } } } @@ -178,8 +204,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option< Helpers.toDistFn(Truncate(None, Some(float)), dist) | ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) => Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist) - | (("mx" | "mixture"), args) => - Helpers.mixture(args) -> Some + | ("mx" | "mixture", args) => Helpers.mixture(args)->Some | ("log", [EvDistribution(a)]) => Helpers.twoDiststoDistFn(Algebraic, "log", a, GenericDist.fromFloat(Math.e))->Some | ("log10", [EvDistribution(a)]) => @@ -221,9 +246,8 @@ let genericOutputToReducerValue = (o: DistributionOperation.outputType): result< | GenDistError(NotYetImplemented) => Error(RETodo("Function not yet implemented")) | GenDistError(Unreachable) => Error(RETodo("Unreachable")) | GenDistError(DistributionVerticalShiftIsInvalid) => - Error(RETodo("Distribution Vertical Shift Is Invalid")) - | GenDistError(ArgumentError(err)) => - Error(RETodo("Argument Error: " ++ err)) + Error(RETodo("Distribution Vertical Shift Is Invalid")) + | GenDistError(ArgumentError(err)) => Error(RETodo("Argument Error: " ++ err)) | GenDistError(Other(s)) => Error(RETodo(s)) }