From 7237f2709b22e37225cb25447fbf989916b5e5f3 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Sat, 20 Aug 2022 20:32:41 -0700 Subject: [PATCH] First implementation of sampleSet mixed distribution --- .../Distributions/DistributionOperation.res | 2 +- .../rescript/Distributions/GenericDist.res | 21 ++++++++++++++++--- .../rescript/Distributions/GenericDist.resi | 1 + .../SampleSetDist/SampleSetDist.res | 12 +++++++++++ .../squiggle-lang/src/rescript/Utility/E.res | 1 + 5 files changed, 33 insertions(+), 4 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 319535c1..9c61211e 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -216,7 +216,7 @@ let rec run = (~env: env, functionCallInfo: functionCallInfo): outputType => { | FromFloat(subFnName, x) => reCall(~functionCallInfo=FromFloat(subFnName, x), ()) | Mixture(dists) => dists - ->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd) + ->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd, ~env) ->E.R2.fmap(r => Dist(r)) ->OutputLocal.fromResult | FromSamples(xs) => diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index f536d54d..0676ab44 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -491,15 +491,30 @@ let pointwiseCombinationFloat = ( m->E.R2.fmap(r => DistributionTypes.PointSet(r)) } -//Note: The result should always cumulatively sum to 1. This would be good to test. -//Note: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this. +//TODO: The result should always cumulatively sum to 1. This would be good to test. +//TODO: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this. let mixture = ( values: array<(t, float)>, ~scaleMultiplyFn: scaleMultiplyFn, ~pointwiseAddFn: pointwiseAddFn, + ~env: env, ) => { - if E.A.length(values) == 0 { + let allValuesAreSampleSet = v => E.A.all(((t, _)) => isSampleSetSet(t), v) + + if E.A.isEmpty(values) { Error(DistributionTypes.OtherError("Mixture error: mixture must have at least 1 element")) + } else if allValuesAreSampleSet(values) { + let withSampleSetValues = values->E.A2.fmap(((value, weight)) => + switch value { + | SampleSet(sampleSet) => Ok((sampleSet, weight)) + | _ => Error("Unreachable") + } |> E.R.toExn("Mixture coding error: SampleSet expected. This should be inaccessible.") + ) + let sampleSetMixture = SampleSetDist.mixture(withSampleSetValues, env.sampleCount) + switch sampleSetMixture { + | Ok(sampleSet) => Ok(DistributionTypes.SampleSet(sampleSet)) + | Error(err) => Error(DistributionTypes.Error.sampleErrorToDistErr(err)) + } } else { let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum let properlyWeightedValues = diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index fd04212a..94fe44ad 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -81,6 +81,7 @@ let mixture: ( array<(t, float)>, ~scaleMultiplyFn: scaleMultiplyFn, ~pointwiseAddFn: pointwiseAddFn, + ~env: env, ) => result let isSymbolic: t => bool diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res index dc15f7a1..f8e93df1 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res @@ -131,3 +131,15 @@ let max = t => T.get(t)->E.A.Floats.max let stdev = t => T.get(t)->E.A.Floats.stdev let variance = t => T.get(t)->E.A.Floats.variance let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f) + +let mixture = (values: array<(t, float)>, intendedLength: int) => { + let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum + values + ->E.A2.fmap(((dist, weight)) => { + let adjustedWeight = weight /. totalWeight + let samplesToGet = adjustedWeight *. E.I.toFloat(intendedLength) |> E.Float.toInt + sampleN(dist, samplesToGet) + }) + ->E.A.concatMany + ->T.make +} diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 22c8c525..5930db23 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -220,6 +220,7 @@ module I = { let increment = n => n + 1 let decrement = n => n - 1 let toString = Js.Int.toString + let toFloat = Js.Int.toFloat } exception Assertion(string)