Fix unweighted average of distributions
This commit is contained in:
parent
38135f0c81
commit
117c08bfa9
|
@ -3,7 +3,8 @@
|
|||
"name": "squiggle",
|
||||
"scripts": {
|
||||
"nodeclean": "rm -r node_modules && rm -r packages/*/node_modules",
|
||||
"format:all": "prettier --write . && cd packages/squiggle-lang && yarn format"
|
||||
"format:all": "prettier --write . && cd packages/squiggle-lang && yarn format",
|
||||
"lint:all": "prettier --check . && cd packages/squiggle-lang && yarn lint:rescript"
|
||||
},
|
||||
"devDependencies": {
|
||||
"prettier": "^2.6.2"
|
||||
|
|
4
packages/squiggle-lang/.prettierignore
Normal file
4
packages/squiggle-lang/.prettierignore
Normal file
|
@ -0,0 +1,4 @@
|
|||
dist
|
||||
lib
|
||||
*.bs.js
|
||||
*.gen.tsx
|
|
@ -90,16 +90,8 @@ describe("eval on distribution functions", () => {
|
|||
})
|
||||
|
||||
describe("mixture", () => {
|
||||
testEval(
|
||||
~skip=true,
|
||||
"mx(normal(5,2), normal(10,1), normal(15, 1))",
|
||||
"Ok(Point Set Distribution)",
|
||||
)
|
||||
testEval(
|
||||
~skip=true,
|
||||
"mixture(normal(5,2), normal(10,1), [.2,, .4])",
|
||||
"Ok(Point Set Distribution)",
|
||||
)
|
||||
testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)")
|
||||
testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)")
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -72,7 +72,8 @@ module Helpers = {
|
|||
| _ => Error("Not a number")
|
||||
}
|
||||
|
||||
let parseNumberArray = (ags: array<expressionValue>) : Belt.Result.t<array<float>, string> => E.A.fmap(parseNumber, ags) |> E.A.R.firstErrorOrOpen
|
||||
let parseNumberArray = (ags: array<expressionValue>): Belt.Result.t<array<float>, string> =>
|
||||
E.A.fmap(parseNumber, ags) |> E.A.R.firstErrorOrOpen
|
||||
|
||||
let parseDist = (args: expressionValue): Belt.Result.t<GenericDist_Types.genericDist, string> =>
|
||||
switch args {
|
||||
|
@ -81,21 +82,46 @@ module Helpers = {
|
|||
| _ => Error("Not a distribution")
|
||||
}
|
||||
|
||||
let parseDistributionArray = (ags: array<expressionValue>) : Belt.Result.t<array<GenericDist_Types.genericDist>, string> => E.A.fmap(parseDist, ags) |> E.A.R.firstErrorOrOpen
|
||||
let parseDistributionArray = (ags: array<expressionValue>): Belt.Result.t<
|
||||
array<GenericDist_Types.genericDist>,
|
||||
string,
|
||||
> => E.A.fmap(parseDist, ags) |> E.A.R.firstErrorOrOpen
|
||||
|
||||
let mixtureWithGivenWeights = (
|
||||
distributions: array<GenericDist_Types.genericDist>,
|
||||
weights: array<float>,
|
||||
): DistributionOperation.outputType =>
|
||||
E.A.length(distributions) == E.A.length(weights)
|
||||
? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation
|
||||
: GenDistError(
|
||||
ArgumentError("Error, mixture call has different number of distributions and weights"),
|
||||
)
|
||||
|
||||
let mixtureWithDefaultWeights = (
|
||||
distributions: array<GenericDist_Types.genericDist>,
|
||||
): DistributionOperation.outputType => {
|
||||
let length = E.A.length(distributions)
|
||||
let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length))
|
||||
mixtureWithGivenWeights(distributions, weights)
|
||||
}
|
||||
|
||||
let mixture = (args: array<expressionValue>): DistributionOperation.outputType => {
|
||||
let givenWeights = E.A.last(args)
|
||||
let calculatedWeights =
|
||||
switch givenWeights {
|
||||
| Some(EvArray(b)) => parseNumberArray(b)
|
||||
| None =>
|
||||
Ok(Belt.Array.make(E.A.length(args), 1.0 /. Belt.Int.toFloat(E.A.length(args))))
|
||||
| _ => Error("Last argument of mx must be array")
|
||||
switch E.A.last(args) {
|
||||
| Some(EvArray(b)) => {
|
||||
let weights = parseNumberArray(b)
|
||||
let distributions = parseDistributionArray(
|
||||
Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1),
|
||||
)
|
||||
switch E.R.merge(distributions, weights) {
|
||||
| Ok(d, w) => mixtureWithGivenWeights(d, w)
|
||||
| Error(err) => GenDistError(ArgumentError(err))
|
||||
}
|
||||
switch (parseDistributionArray(Belt.Array.slice(args, ~offset=0, ~len=Belt.Array.length(args)-1)), calculatedWeights) {
|
||||
| (Ok(distArray), Ok(w)) => Mixture(Belt.Array.zip(distArray, w)) -> runGenericOperation
|
||||
| (Error(err), _) => GenDistError(ArgumentError(err))
|
||||
| (_, Error(err)) => GenDistError(ArgumentError(err))
|
||||
}
|
||||
| Some(EvDistribution(b)) => switch parseDistributionArray(args) {
|
||||
| Ok(distributions) => mixtureWithDefaultWeights(distributions)
|
||||
| Error(err) => GenDistError(ArgumentError(err))
|
||||
}
|
||||
| _ => GenDistError(ArgumentError("Last argument of mx must be array or distribution"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -178,8 +204,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option<
|
|||
Helpers.toDistFn(Truncate(None, Some(float)), dist)
|
||||
| ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) =>
|
||||
Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist)
|
||||
| (("mx" | "mixture"), args) =>
|
||||
Helpers.mixture(args) -> Some
|
||||
| ("mx" | "mixture", args) => Helpers.mixture(args)->Some
|
||||
| ("log", [EvDistribution(a)]) =>
|
||||
Helpers.twoDiststoDistFn(Algebraic, "log", a, GenericDist.fromFloat(Math.e))->Some
|
||||
| ("log10", [EvDistribution(a)]) =>
|
||||
|
@ -222,8 +247,7 @@ let genericOutputToReducerValue = (o: DistributionOperation.outputType): result<
|
|||
| GenDistError(Unreachable) => Error(RETodo("Unreachable"))
|
||||
| GenDistError(DistributionVerticalShiftIsInvalid) =>
|
||||
Error(RETodo("Distribution Vertical Shift Is Invalid"))
|
||||
| GenDistError(ArgumentError(err)) =>
|
||||
Error(RETodo("Argument Error: " ++ err))
|
||||
| GenDistError(ArgumentError(err)) => Error(RETodo("Argument Error: " ++ err))
|
||||
| GenDistError(Other(s)) => Error(RETodo(s))
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user