Merge pull request #233 from quantified-uncertainty/mix-distributions

Mix distributions
This commit is contained in:
Ozzie Gooen 2022-04-12 22:19:35 -04:00 committed by GitHub
commit c6e78a1fd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 80 additions and 22 deletions

View File

@ -3,7 +3,8 @@
"name": "squiggle", "name": "squiggle",
"scripts": { "scripts": {
"nodeclean": "rm -r node_modules && rm -r packages/*/node_modules", "nodeclean": "rm -r node_modules && rm -r packages/*/node_modules",
"format:all": "prettier --write . && cd packages/squiggle-lang && yarn format" "format:all": "prettier --write . && cd packages/squiggle-lang && yarn format",
"lint:all": "prettier --check . && cd packages/squiggle-lang && yarn lint:rescript"
}, },
"devDependencies": { "devDependencies": {
"prettier": "^2.6.2" "prettier": "^2.6.2"

View File

@ -37,7 +37,7 @@ could be continuous, discrete or mixed.
<Story <Story
name="Discrete" name="Discrete"
args={{ args={{
squiggleString: "mm(0, 1, 3, 5, 8, 10, [0.1, 0.8, 0.5, 0.3, 0.2, 0.1])", squiggleString: "mx(0, 1, 3, 5, 8, 10, [0.1, 0.8, 0.5, 0.3, 0.2, 0.1])",
}} }}
> >
{Template.bind({})} {Template.bind({})}
@ -51,7 +51,7 @@ could be continuous, discrete or mixed.
name="Mixed" name="Mixed"
args={{ args={{
squiggleString: squiggleString:
"mm(0, 1, 3, 5, 8, normal(8, 1), [0.1, 0.3, 0.4, 0.35, 0.2, 0.8])", "mx(0, 1, 3, 5, 8, normal(8, 1), [0.1, 0.3, 0.4, 0.35, 0.2, 0.8])",
}} }}
> >
{Template.bind({})} {Template.bind({})}

View File

@ -130,10 +130,6 @@
}, },
"encode": { "encode": {
"enter": { "enter": {
"y2": {
"scale": "yscale",
"value": 0
},
"width": { "width": {
"value": 1 "value": 1
} }
@ -146,6 +142,10 @@
"y": { "y": {
"scale": "yscale", "scale": "yscale",
"field": "y" "field": "y"
},
"y2": {
"scale": "yscale",
"value": 0
} }
} }
} }
@ -160,7 +160,7 @@
"shape": { "shape": {
"value": "circle" "value": "circle"
}, },
"size": [{ "value": 30 }], "size": [{ "value": 100 }],
"tooltip": { "tooltip": {
"signal": "datum.y" "signal": "datum.y"
} }

View File

@ -0,0 +1,4 @@
dist
lib
*.bs.js
*.gen.tsx

View File

@ -90,16 +90,8 @@ describe("eval on distribution functions", () => {
}) })
describe("mixture", () => { describe("mixture", () => {
testEval( testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)")
~skip=true, testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)")
"mx(normal(5,2), normal(10,1), normal(15, 1))",
"Ok(Point Set Distribution)",
)
testEval(
~skip=true,
"mixture(normal(5,2), normal(10,1), [.2,, .4])",
"Ok(Point Set Distribution)",
)
}) })
}) })

View File

@ -8,6 +8,7 @@ type error =
| NotYetImplemented | NotYetImplemented
| Unreachable | Unreachable
| DistributionVerticalShiftIsInvalid | DistributionVerticalShiftIsInvalid
| ArgumentError(string)
| Other(string) | Other(string)
@genType @genType

View File

@ -66,6 +66,64 @@ module Helpers = {
dist1, dist1,
)->runGenericOperation )->runGenericOperation
} }
let parseNumber = (args: expressionValue): Belt.Result.t<float, string> =>
switch args {
| EvNumber(x) => Ok(x)
| _ => Error("Not a number")
}
let parseNumberArray = (ags: array<expressionValue>): Belt.Result.t<array<float>, string> =>
E.A.fmap(parseNumber, ags) |> E.A.R.firstErrorOrOpen
let parseDist = (args: expressionValue): Belt.Result.t<GenericDist_Types.genericDist, string> =>
switch args {
| EvDistribution(x) => Ok(x)
| EvNumber(x) => Ok(GenericDist.fromFloat(x))
| _ => Error("Not a distribution")
}
let parseDistributionArray = (ags: array<expressionValue>): Belt.Result.t<
array<GenericDist_Types.genericDist>,
string,
> => E.A.fmap(parseDist, ags) |> E.A.R.firstErrorOrOpen
let mixtureWithGivenWeights = (
distributions: array<GenericDist_Types.genericDist>,
weights: array<float>,
): DistributionOperation.outputType =>
E.A.length(distributions) == E.A.length(weights)
? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation
: GenDistError(
ArgumentError("Error, mixture call has different number of distributions and weights"),
)
let mixtureWithDefaultWeights = (
distributions: array<GenericDist_Types.genericDist>,
): DistributionOperation.outputType => {
let length = E.A.length(distributions)
let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length))
mixtureWithGivenWeights(distributions, weights)
}
let mixture = (args: array<expressionValue>): DistributionOperation.outputType => {
switch E.A.last(args) {
| Some(EvArray(b)) => {
let weights = parseNumberArray(b)
let distributions = parseDistributionArray(
Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1),
)
switch E.R.merge(distributions, weights) {
| Ok(d, w) => mixtureWithGivenWeights(d, w)
| Error(err) => GenDistError(ArgumentError(err))
}
}
| Some(EvDistribution(b)) => switch parseDistributionArray(args) {
| Ok(distributions) => mixtureWithDefaultWeights(distributions)
| Error(err) => GenDistError(ArgumentError(err))
}
| _ => GenDistError(ArgumentError("Last argument of mx must be array or distribution"))
}
}
} }
module SymbolicConstructors = { module SymbolicConstructors = {
@ -146,6 +204,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option<
Helpers.toDistFn(Truncate(None, Some(float)), dist) Helpers.toDistFn(Truncate(None, Some(float)), dist)
| ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) => | ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) =>
Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist) Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist)
| ("mx" | "mixture", args) => Helpers.mixture(args)->Some
| ("log", [EvDistribution(a)]) => | ("log", [EvDistribution(a)]) =>
Helpers.twoDiststoDistFn(Algebraic, "log", a, GenericDist.fromFloat(Math.e))->Some Helpers.twoDiststoDistFn(Algebraic, "log", a, GenericDist.fromFloat(Math.e))->Some
| ("log10", [EvDistribution(a)]) => | ("log10", [EvDistribution(a)]) =>
@ -187,7 +246,8 @@ let genericOutputToReducerValue = (o: DistributionOperation.outputType): result<
| GenDistError(NotYetImplemented) => Error(RETodo("Function not yet implemented")) | GenDistError(NotYetImplemented) => Error(RETodo("Function not yet implemented"))
| GenDistError(Unreachable) => Error(RETodo("Unreachable")) | GenDistError(Unreachable) => Error(RETodo("Unreachable"))
| GenDistError(DistributionVerticalShiftIsInvalid) => | GenDistError(DistributionVerticalShiftIsInvalid) =>
Error(RETodo("Distribution Vertical Shift is Invalid")) Error(RETodo("Distribution Vertical Shift Is Invalid"))
| GenDistError(ArgumentError(err)) => Error(RETodo("Argument Error: " ++ err))
| GenDistError(Other(s)) => Error(RETodo(s)) | GenDistError(Other(s)) => Error(RETodo(s))
} }

View File

@ -68,15 +68,15 @@ combination of the two. The first positional arguments represent the distributio
to be combined, and the last argument is how much to weigh every distribution in the to be combined, and the last argument is how much to weigh every distribution in the
combination. combination.
<SquiggleEditor initialSquiggleString="mm(uniform(0,1), normal(1,1), [0.5, 0.5])" /> <SquiggleEditor initialSquiggleString="mx(uniform(0,1), normal(1,1), [0.5, 0.5])" />
It's possible to create discrete distributions using this method. It's possible to create discrete distributions using this method.
<SquiggleEditor initialSquiggleString="mm(0, 1, [0.2,0.8])" /> <SquiggleEditor initialSquiggleString="mx(0, 1, [0.2,0.8])" />
As well as mixed distributions: As well as mixed distributions:
<SquiggleEditor initialSquiggleString="mm(3, 8, 1 to 10, [0.2, 0.3, 0.5])" /> <SquiggleEditor initialSquiggleString="mx(3, 8, 1 to 10, [0.2, 0.3, 0.5])" />
## Other Functions ## Other Functions