Merge pull request #233 from quantified-uncertainty/mix-distributions
Mix distributions
This commit is contained in:
commit
c6e78a1fd4
|
@ -3,7 +3,8 @@
|
||||||
"name": "squiggle",
|
"name": "squiggle",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"nodeclean": "rm -r node_modules && rm -r packages/*/node_modules",
|
"nodeclean": "rm -r node_modules && rm -r packages/*/node_modules",
|
||||||
"format:all": "prettier --write . && cd packages/squiggle-lang && yarn format"
|
"format:all": "prettier --write . && cd packages/squiggle-lang && yarn format",
|
||||||
|
"lint:all": "prettier --check . && cd packages/squiggle-lang && yarn lint:rescript"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"prettier": "^2.6.2"
|
"prettier": "^2.6.2"
|
||||||
|
|
|
@ -37,7 +37,7 @@ could be continuous, discrete or mixed.
|
||||||
<Story
|
<Story
|
||||||
name="Discrete"
|
name="Discrete"
|
||||||
args={{
|
args={{
|
||||||
squiggleString: "mm(0, 1, 3, 5, 8, 10, [0.1, 0.8, 0.5, 0.3, 0.2, 0.1])",
|
squiggleString: "mx(0, 1, 3, 5, 8, 10, [0.1, 0.8, 0.5, 0.3, 0.2, 0.1])",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{Template.bind({})}
|
{Template.bind({})}
|
||||||
|
@ -51,7 +51,7 @@ could be continuous, discrete or mixed.
|
||||||
name="Mixed"
|
name="Mixed"
|
||||||
args={{
|
args={{
|
||||||
squiggleString:
|
squiggleString:
|
||||||
"mm(0, 1, 3, 5, 8, normal(8, 1), [0.1, 0.3, 0.4, 0.35, 0.2, 0.8])",
|
"mx(0, 1, 3, 5, 8, normal(8, 1), [0.1, 0.3, 0.4, 0.35, 0.2, 0.8])",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{Template.bind({})}
|
{Template.bind({})}
|
||||||
|
|
|
@ -130,10 +130,6 @@
|
||||||
},
|
},
|
||||||
"encode": {
|
"encode": {
|
||||||
"enter": {
|
"enter": {
|
||||||
"y2": {
|
|
||||||
"scale": "yscale",
|
|
||||||
"value": 0
|
|
||||||
},
|
|
||||||
"width": {
|
"width": {
|
||||||
"value": 1
|
"value": 1
|
||||||
}
|
}
|
||||||
|
@ -146,6 +142,10 @@
|
||||||
"y": {
|
"y": {
|
||||||
"scale": "yscale",
|
"scale": "yscale",
|
||||||
"field": "y"
|
"field": "y"
|
||||||
|
},
|
||||||
|
"y2": {
|
||||||
|
"scale": "yscale",
|
||||||
|
"value": 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -160,7 +160,7 @@
|
||||||
"shape": {
|
"shape": {
|
||||||
"value": "circle"
|
"value": "circle"
|
||||||
},
|
},
|
||||||
"size": [{ "value": 30 }],
|
"size": [{ "value": 100 }],
|
||||||
"tooltip": {
|
"tooltip": {
|
||||||
"signal": "datum.y"
|
"signal": "datum.y"
|
||||||
}
|
}
|
||||||
|
|
4
packages/squiggle-lang/.prettierignore
Normal file
4
packages/squiggle-lang/.prettierignore
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
dist
|
||||||
|
lib
|
||||||
|
*.bs.js
|
||||||
|
*.gen.tsx
|
|
@ -90,16 +90,8 @@ describe("eval on distribution functions", () => {
|
||||||
})
|
})
|
||||||
|
|
||||||
describe("mixture", () => {
|
describe("mixture", () => {
|
||||||
testEval(
|
testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)")
|
||||||
~skip=true,
|
testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)")
|
||||||
"mx(normal(5,2), normal(10,1), normal(15, 1))",
|
|
||||||
"Ok(Point Set Distribution)",
|
|
||||||
)
|
|
||||||
testEval(
|
|
||||||
~skip=true,
|
|
||||||
"mixture(normal(5,2), normal(10,1), [.2,, .4])",
|
|
||||||
"Ok(Point Set Distribution)",
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ type error =
|
||||||
| NotYetImplemented
|
| NotYetImplemented
|
||||||
| Unreachable
|
| Unreachable
|
||||||
| DistributionVerticalShiftIsInvalid
|
| DistributionVerticalShiftIsInvalid
|
||||||
|
| ArgumentError(string)
|
||||||
| Other(string)
|
| Other(string)
|
||||||
|
|
||||||
@genType
|
@genType
|
||||||
|
|
|
@ -66,6 +66,64 @@ module Helpers = {
|
||||||
dist1,
|
dist1,
|
||||||
)->runGenericOperation
|
)->runGenericOperation
|
||||||
}
|
}
|
||||||
|
let parseNumber = (args: expressionValue): Belt.Result.t<float, string> =>
|
||||||
|
switch args {
|
||||||
|
| EvNumber(x) => Ok(x)
|
||||||
|
| _ => Error("Not a number")
|
||||||
|
}
|
||||||
|
|
||||||
|
let parseNumberArray = (ags: array<expressionValue>): Belt.Result.t<array<float>, string> =>
|
||||||
|
E.A.fmap(parseNumber, ags) |> E.A.R.firstErrorOrOpen
|
||||||
|
|
||||||
|
let parseDist = (args: expressionValue): Belt.Result.t<GenericDist_Types.genericDist, string> =>
|
||||||
|
switch args {
|
||||||
|
| EvDistribution(x) => Ok(x)
|
||||||
|
| EvNumber(x) => Ok(GenericDist.fromFloat(x))
|
||||||
|
| _ => Error("Not a distribution")
|
||||||
|
}
|
||||||
|
|
||||||
|
let parseDistributionArray = (ags: array<expressionValue>): Belt.Result.t<
|
||||||
|
array<GenericDist_Types.genericDist>,
|
||||||
|
string,
|
||||||
|
> => E.A.fmap(parseDist, ags) |> E.A.R.firstErrorOrOpen
|
||||||
|
|
||||||
|
let mixtureWithGivenWeights = (
|
||||||
|
distributions: array<GenericDist_Types.genericDist>,
|
||||||
|
weights: array<float>,
|
||||||
|
): DistributionOperation.outputType =>
|
||||||
|
E.A.length(distributions) == E.A.length(weights)
|
||||||
|
? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation
|
||||||
|
: GenDistError(
|
||||||
|
ArgumentError("Error, mixture call has different number of distributions and weights"),
|
||||||
|
)
|
||||||
|
|
||||||
|
let mixtureWithDefaultWeights = (
|
||||||
|
distributions: array<GenericDist_Types.genericDist>,
|
||||||
|
): DistributionOperation.outputType => {
|
||||||
|
let length = E.A.length(distributions)
|
||||||
|
let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length))
|
||||||
|
mixtureWithGivenWeights(distributions, weights)
|
||||||
|
}
|
||||||
|
|
||||||
|
let mixture = (args: array<expressionValue>): DistributionOperation.outputType => {
|
||||||
|
switch E.A.last(args) {
|
||||||
|
| Some(EvArray(b)) => {
|
||||||
|
let weights = parseNumberArray(b)
|
||||||
|
let distributions = parseDistributionArray(
|
||||||
|
Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1),
|
||||||
|
)
|
||||||
|
switch E.R.merge(distributions, weights) {
|
||||||
|
| Ok(d, w) => mixtureWithGivenWeights(d, w)
|
||||||
|
| Error(err) => GenDistError(ArgumentError(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
| Some(EvDistribution(b)) => switch parseDistributionArray(args) {
|
||||||
|
| Ok(distributions) => mixtureWithDefaultWeights(distributions)
|
||||||
|
| Error(err) => GenDistError(ArgumentError(err))
|
||||||
|
}
|
||||||
|
| _ => GenDistError(ArgumentError("Last argument of mx must be array or distribution"))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
module SymbolicConstructors = {
|
module SymbolicConstructors = {
|
||||||
|
@ -146,6 +204,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option<
|
||||||
Helpers.toDistFn(Truncate(None, Some(float)), dist)
|
Helpers.toDistFn(Truncate(None, Some(float)), dist)
|
||||||
| ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) =>
|
| ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) =>
|
||||||
Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist)
|
Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist)
|
||||||
|
| ("mx" | "mixture", args) => Helpers.mixture(args)->Some
|
||||||
| ("log", [EvDistribution(a)]) =>
|
| ("log", [EvDistribution(a)]) =>
|
||||||
Helpers.twoDiststoDistFn(Algebraic, "log", a, GenericDist.fromFloat(Math.e))->Some
|
Helpers.twoDiststoDistFn(Algebraic, "log", a, GenericDist.fromFloat(Math.e))->Some
|
||||||
| ("log10", [EvDistribution(a)]) =>
|
| ("log10", [EvDistribution(a)]) =>
|
||||||
|
@ -187,7 +246,8 @@ let genericOutputToReducerValue = (o: DistributionOperation.outputType): result<
|
||||||
| GenDistError(NotYetImplemented) => Error(RETodo("Function not yet implemented"))
|
| GenDistError(NotYetImplemented) => Error(RETodo("Function not yet implemented"))
|
||||||
| GenDistError(Unreachable) => Error(RETodo("Unreachable"))
|
| GenDistError(Unreachable) => Error(RETodo("Unreachable"))
|
||||||
| GenDistError(DistributionVerticalShiftIsInvalid) =>
|
| GenDistError(DistributionVerticalShiftIsInvalid) =>
|
||||||
Error(RETodo("Distribution Vertical Shift is Invalid"))
|
Error(RETodo("Distribution Vertical Shift Is Invalid"))
|
||||||
|
| GenDistError(ArgumentError(err)) => Error(RETodo("Argument Error: " ++ err))
|
||||||
| GenDistError(Other(s)) => Error(RETodo(s))
|
| GenDistError(Other(s)) => Error(RETodo(s))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -68,15 +68,15 @@ combination of the two. The first positional arguments represent the distributio
|
||||||
to be combined, and the last argument is how much to weigh every distribution in the
|
to be combined, and the last argument is how much to weigh every distribution in the
|
||||||
combination.
|
combination.
|
||||||
|
|
||||||
<SquiggleEditor initialSquiggleString="mm(uniform(0,1), normal(1,1), [0.5, 0.5])" />
|
<SquiggleEditor initialSquiggleString="mx(uniform(0,1), normal(1,1), [0.5, 0.5])" />
|
||||||
|
|
||||||
It's possible to create discrete distributions using this method.
|
It's possible to create discrete distributions using this method.
|
||||||
|
|
||||||
<SquiggleEditor initialSquiggleString="mm(0, 1, [0.2,0.8])" />
|
<SquiggleEditor initialSquiggleString="mx(0, 1, [0.2,0.8])" />
|
||||||
|
|
||||||
As well as mixed distributions:
|
As well as mixed distributions:
|
||||||
|
|
||||||
<SquiggleEditor initialSquiggleString="mm(3, 8, 1 to 10, [0.2, 0.3, 0.5])" />
|
<SquiggleEditor initialSquiggleString="mx(3, 8, 1 to 10, [0.2, 0.3, 0.5])" />
|
||||||
|
|
||||||
## Other Functions
|
## Other Functions
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user