Merge pull request #233 from quantified-uncertainty/mix-distributions
Mix distributions
This commit is contained in:
		
						commit
						c6e78a1fd4
					
				|  | @ -3,7 +3,8 @@ | |||
|   "name": "squiggle", | ||||
|   "scripts": { | ||||
|     "nodeclean": "rm -r node_modules && rm -r packages/*/node_modules", | ||||
|     "format:all": "prettier --write . && cd packages/squiggle-lang && yarn format" | ||||
|     "format:all": "prettier --write . && cd packages/squiggle-lang && yarn format", | ||||
|     "lint:all": "prettier --check . && cd packages/squiggle-lang && yarn lint:rescript" | ||||
|   }, | ||||
|   "devDependencies": { | ||||
|     "prettier": "^2.6.2" | ||||
|  |  | |||
|  | @ -37,7 +37,7 @@ could be continuous, discrete or mixed. | |||
|   <Story | ||||
|     name="Discrete" | ||||
|     args={{ | ||||
|       squiggleString: "mm(0, 1, 3, 5, 8, 10, [0.1, 0.8, 0.5, 0.3, 0.2, 0.1])", | ||||
|       squiggleString: "mx(0, 1, 3, 5, 8, 10, [0.1, 0.8, 0.5, 0.3, 0.2, 0.1])", | ||||
|     }} | ||||
|   > | ||||
|     {Template.bind({})} | ||||
|  | @ -51,7 +51,7 @@ could be continuous, discrete or mixed. | |||
|     name="Mixed" | ||||
|     args={{ | ||||
|       squiggleString: | ||||
|         "mm(0, 1, 3, 5, 8, normal(8, 1), [0.1, 0.3, 0.4, 0.35, 0.2, 0.8])", | ||||
|         "mx(0, 1, 3, 5, 8, normal(8, 1), [0.1, 0.3, 0.4, 0.35, 0.2, 0.8])", | ||||
|     }} | ||||
|   > | ||||
|     {Template.bind({})} | ||||
|  |  | |||
|  | @ -130,10 +130,6 @@ | |||
|       }, | ||||
|       "encode": { | ||||
|         "enter": { | ||||
|           "y2": { | ||||
|             "scale": "yscale", | ||||
|             "value": 0 | ||||
|           }, | ||||
|           "width": { | ||||
|             "value": 1 | ||||
|           } | ||||
|  | @ -146,6 +142,10 @@ | |||
|           "y": { | ||||
|             "scale": "yscale", | ||||
|             "field": "y" | ||||
|           }, | ||||
|           "y2": { | ||||
|             "scale": "yscale", | ||||
|             "value": 0 | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|  | @ -160,7 +160,7 @@ | |||
|           "shape": { | ||||
|             "value": "circle" | ||||
|           }, | ||||
|           "size": [{ "value": 30 }], | ||||
|           "size": [{ "value": 100 }], | ||||
|           "tooltip": { | ||||
|             "signal": "datum.y" | ||||
|           } | ||||
|  |  | |||
							
								
								
									
										4
									
								
								packages/squiggle-lang/.prettierignore
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								packages/squiggle-lang/.prettierignore
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,4 @@ | |||
| dist | ||||
| lib | ||||
| *.bs.js | ||||
| *.gen.tsx | ||||
|  | @ -90,16 +90,8 @@ describe("eval on distribution functions", () => { | |||
|   }) | ||||
| 
 | ||||
|   describe("mixture", () => { | ||||
|     testEval( | ||||
|       ~skip=true, | ||||
|       "mx(normal(5,2), normal(10,1), normal(15, 1))", | ||||
|       "Ok(Point Set Distribution)", | ||||
|     ) | ||||
|     testEval( | ||||
|       ~skip=true, | ||||
|       "mixture(normal(5,2), normal(10,1), [.2,, .4])", | ||||
|       "Ok(Point Set Distribution)", | ||||
|     ) | ||||
|     testEval("mx(normal(5,2), normal(10,1), normal(15, 1))", "Ok(Point Set Distribution)") | ||||
|     testEval("mixture(normal(5,2), normal(10,1), [0.2, 0.4])", "Ok(Point Set Distribution)") | ||||
|   }) | ||||
| }) | ||||
| 
 | ||||
|  |  | |||
|  | @ -8,6 +8,7 @@ type error = | |||
|   | NotYetImplemented | ||||
|   | Unreachable | ||||
|   | DistributionVerticalShiftIsInvalid | ||||
|   | ArgumentError(string) | ||||
|   | Other(string) | ||||
| 
 | ||||
| @genType | ||||
|  |  | |||
|  | @ -66,6 +66,64 @@ module Helpers = { | |||
|       dist1, | ||||
|     )->runGenericOperation | ||||
|   } | ||||
|   let parseNumber = (args: expressionValue): Belt.Result.t<float, string> => | ||||
|     switch args { | ||||
|     | EvNumber(x) => Ok(x) | ||||
|     | _ => Error("Not a number") | ||||
|     } | ||||
| 
 | ||||
|   let parseNumberArray = (ags: array<expressionValue>): Belt.Result.t<array<float>, string> => | ||||
|     E.A.fmap(parseNumber, ags) |> E.A.R.firstErrorOrOpen | ||||
| 
 | ||||
|   let parseDist = (args: expressionValue): Belt.Result.t<GenericDist_Types.genericDist, string> => | ||||
|     switch args { | ||||
|     | EvDistribution(x) => Ok(x) | ||||
|     | EvNumber(x) => Ok(GenericDist.fromFloat(x)) | ||||
|     | _ => Error("Not a distribution") | ||||
|     } | ||||
| 
 | ||||
|   let parseDistributionArray = (ags: array<expressionValue>): Belt.Result.t< | ||||
|     array<GenericDist_Types.genericDist>, | ||||
|     string, | ||||
|   > => E.A.fmap(parseDist, ags) |> E.A.R.firstErrorOrOpen | ||||
| 
 | ||||
|   let mixtureWithGivenWeights = ( | ||||
|     distributions: array<GenericDist_Types.genericDist>, | ||||
|     weights: array<float>, | ||||
|   ): DistributionOperation.outputType => | ||||
|     E.A.length(distributions) == E.A.length(weights) | ||||
|       ? Mixture(Belt.Array.zip(distributions, weights))->runGenericOperation | ||||
|       : GenDistError( | ||||
|           ArgumentError("Error, mixture call has different number of distributions and weights"), | ||||
|         ) | ||||
| 
 | ||||
|   let mixtureWithDefaultWeights = ( | ||||
|     distributions: array<GenericDist_Types.genericDist>, | ||||
|   ): DistributionOperation.outputType => { | ||||
|     let length = E.A.length(distributions) | ||||
|     let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length)) | ||||
|     mixtureWithGivenWeights(distributions, weights) | ||||
|   } | ||||
| 
 | ||||
|   let mixture = (args: array<expressionValue>): DistributionOperation.outputType => { | ||||
|     switch E.A.last(args) { | ||||
|     | Some(EvArray(b)) => { | ||||
|         let weights = parseNumberArray(b) | ||||
|         let distributions = parseDistributionArray( | ||||
|           Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1), | ||||
|         ) | ||||
|         switch E.R.merge(distributions, weights) { | ||||
|         | Ok(d, w) => mixtureWithGivenWeights(d, w) | ||||
|         | Error(err) => GenDistError(ArgumentError(err)) | ||||
|         } | ||||
|       } | ||||
|     | Some(EvDistribution(b)) => switch parseDistributionArray(args) { | ||||
|       | Ok(distributions) => mixtureWithDefaultWeights(distributions) | ||||
|       | Error(err) => GenDistError(ArgumentError(err)) | ||||
|       } | ||||
|     | _ => GenDistError(ArgumentError("Last argument of mx must be array or distribution")) | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| module SymbolicConstructors = { | ||||
|  | @ -146,6 +204,7 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall): option< | |||
|     Helpers.toDistFn(Truncate(None, Some(float)), dist) | ||||
|   | ("truncate", [EvDistribution(dist), EvNumber(float1), EvNumber(float2)]) => | ||||
|     Helpers.toDistFn(Truncate(Some(float1), Some(float2)), dist) | ||||
|   | ("mx" | "mixture", args) => Helpers.mixture(args)->Some | ||||
|   | ("log", [EvDistribution(a)]) => | ||||
|     Helpers.twoDiststoDistFn(Algebraic, "log", a, GenericDist.fromFloat(Math.e))->Some | ||||
|   | ("log10", [EvDistribution(a)]) => | ||||
|  | @ -187,7 +246,8 @@ let genericOutputToReducerValue = (o: DistributionOperation.outputType): result< | |||
|   | GenDistError(NotYetImplemented) => Error(RETodo("Function not yet implemented")) | ||||
|   | GenDistError(Unreachable) => Error(RETodo("Unreachable")) | ||||
|   | GenDistError(DistributionVerticalShiftIsInvalid) => | ||||
|     Error(RETodo("Distribution Vertical Shift is Invalid")) | ||||
|     Error(RETodo("Distribution Vertical Shift Is Invalid")) | ||||
|   | GenDistError(ArgumentError(err)) => Error(RETodo("Argument Error: " ++ err)) | ||||
|   | GenDistError(Other(s)) => Error(RETodo(s)) | ||||
|   } | ||||
| 
 | ||||
|  |  | |||
|  | @ -68,15 +68,15 @@ combination of the two. The first positional arguments represent the distributio | |||
| to be combined, and the last argument is how much to weigh every distribution in the | ||||
| combination. | ||||
| 
 | ||||
| <SquiggleEditor initialSquiggleString="mm(uniform(0,1), normal(1,1), [0.5, 0.5])" /> | ||||
| <SquiggleEditor initialSquiggleString="mx(uniform(0,1), normal(1,1), [0.5, 0.5])" /> | ||||
| 
 | ||||
| It's possible to create discrete distributions using this method. | ||||
| 
 | ||||
| <SquiggleEditor initialSquiggleString="mm(0, 1, [0.2,0.8])" /> | ||||
| <SquiggleEditor initialSquiggleString="mx(0, 1, [0.2,0.8])" /> | ||||
| 
 | ||||
| As well as mixed distributions: | ||||
| 
 | ||||
| <SquiggleEditor initialSquiggleString="mm(3, 8, 1 to 10, [0.2, 0.3, 0.5])" /> | ||||
| <SquiggleEditor initialSquiggleString="mx(3, 8, 1 to 10, [0.2, 0.3, 0.5])" /> | ||||
| 
 | ||||
| ## Other Functions | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user