diff --git a/packages/squiggle-lang/__tests__/Distributions/SampleSetDist_test.res b/packages/squiggle-lang/__tests__/Distributions/SampleSetDist_test.res
index 1c430f3d..edfcb6ef 100644
--- a/packages/squiggle-lang/__tests__/Distributions/SampleSetDist_test.res
+++ b/packages/squiggle-lang/__tests__/Distributions/SampleSetDist_test.res
@@ -4,20 +4,34 @@ open TestHelpers
 describe("Continuous and discrete splits", () => {
   makeTest(
     "splits (1)",
-    SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]),
-    ([1.432, 1.33455, 2.0], E.FloatFloatMap.empty()),
+    E.A.Sorted.Floats.splitContinuousAndDiscreteForMinWeight([1.432, 1.33455, 2.0], 2),
+    ([1.33455, 1.432, 2.0], E.FloatFloatMap.empty()),
   )
   makeTest(
     "splits (2)",
-    SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([
-      1.432,
-      1.33455,
-      2.0,
-      2.0,
-      2.0,
-      2.0,
-    ]) |> (((c, disc)) => (c, disc |> E.FloatFloatMap.toArray)),
-    ([1.432, 1.33455], [(2.0, 4.0)]),
+    E.A.Sorted.Floats.splitContinuousAndDiscreteForMinWeight(
+      [1.432, 1.33455, 2.0, 2.0, 2.0, 2.0],
+      2,
+    ) |> (((c, disc)) => (c, disc |> E.FloatFloatMap.toArray)),
+    ([1.33455, 1.432], [(2.0, 4.0)]),
+  )
+
+  makeTest(
+    "splits (3)",
+    E.A.Sorted.Floats.splitContinuousAndDiscreteForMinWeight(
+      [1.432, 1.33455, 2.0, 2.0, 3.5, 3.5, 3.5],
+      3,
+    ) |> (((c, disc)) => (c, disc |> E.FloatFloatMap.toArray)),
+    ([1.33455, 1.432, 2.0, 2.0], [(3.5, 3.0)]),
+  )
+
+  makeTest(
+    "splits (4)",
+    E.A.Sorted.Floats.splitContinuousAndDiscreteForMinWeight(
+      [1.432, 1.33455, 2.0, 2.0, 3.5, 3.5, 3.5],
+      5,
+    ) |> (((c, disc)) => (c, disc |> E.FloatFloatMap.toArray)),
+    ([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], []),
   )
 
   let makeDuplicatedArray = count => {
@@ -26,14 +40,16 @@ describe("Continuous and discrete splits", () => {
     E.A.concatMany([sorted, sorted, sorted, sorted]) |> Belt.SortArray.stableSortBy(_, compare)
   }
 
-  let (_, discrete1) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete(
+  let (_, discrete1) = E.A.Sorted.Floats.splitContinuousAndDiscreteForMinWeight(
     makeDuplicatedArray(10),
+    2,
   )
   let toArr1 = discrete1 |> E.FloatFloatMap.toArray
   makeTest("splitMedium at count=10", toArr1 |> Belt.Array.length, 10)
 
-  let (_c, discrete2) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete(
+  let (_c, discrete2) = E.A.Sorted.Floats.splitContinuousAndDiscreteForMinWeight(
     makeDuplicatedArray(500),
+    2,
   )
   let toArr2 = discrete2 |> E.FloatFloatMap.toArray
   makeTest("splitMedium at count=500", toArr2 |> Belt.Array.length, 500)
diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res
index 90537a12..10874f96 100644
--- a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res
+++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res
@@ -39,28 +39,6 @@ module Internals = {
   module T = {
     type t = array<float>
 
-    let splitContinuousAndDiscrete = (sortedArray: t) => {
-      let continuous = []
-      let discrete = E.FloatFloatMap.empty()
-      Belt.Array.forEachWithIndex(sortedArray, (index, element) => {
-        let maxIndex = (sortedArray |> Array.length) - 1
-        let possiblySimilarElements = switch index {
-        | 0 => [index + 1]
-        | n if n == maxIndex => [index - 1]
-        | _ => [index - 1, index + 1]
-        } |> Belt.Array.map(_, r => sortedArray[r])
-        let hasSimilarElement = Belt.Array.some(possiblySimilarElements, r => r == element)
-        hasSimilarElement
-          ? E.FloatFloatMap.increment(element, discrete)
-          : {
-              let _ = Js.Array.push(element, continuous)
-            }
-
-        ()
-      })
-      (continuous, discrete)
-    }
-
     let xWidthToUnitWidth = (samples, outputXYPoints, xWidth) => {
       let xyPointRange = E.A.Sorted.range(samples) |> E.O.default(0.0)
       let xyPointWidth = xyPointRange /. float_of_int(outputXYPoints)
@@ -85,7 +63,8 @@ let toPointSetDist = (
   (),
 ): Internals.Types.outputs => {
   Array.fast_sort(compare, samples)
-  let (continuousPart, discretePart) = E.A.Sorted.Floats.split(samples)
+  let minDiscreteToKeep = max(2, E.A.length(samples) / 10)
+  let (continuousPart, discretePart) = E.A.Sorted.Floats.splitContinuousAndDiscreteForMinWeight(samples, minDiscreteToKeep)
   let length = samples |> E.A.length |> float_of_int
   let discrete: PointSetTypes.discreteShape =
     discretePart
diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res
index 030c2961..69ded69b 100644
--- a/packages/squiggle-lang/src/rescript/Utility/E.res
+++ b/packages/squiggle-lang/src/rescript/Utility/E.res
@@ -8,7 +8,7 @@ module FloatFloatMap = {
   type t = Belt.MutableMap.t<Id.t, float, Id.identity>
 
   let fromArray = (ar: array<(float, float)>) => Belt.MutableMap.fromArray(ar, ~id=module(Id))
-  let toArray = (t: t) => Belt.MutableMap.toArray(t)
+  let toArray = (t: t): array<(float, float)> => Belt.MutableMap.toArray(t)
   let empty = () => Belt.MutableMap.make(~id=module(Id))
   let increment = (el, t: t) =>
     Belt.MutableMap.update(t, el, x =>
@@ -20,6 +20,10 @@ module FloatFloatMap = {
 
   let get = (el, t: t) => Belt.MutableMap.get(t, el)
   let fmap = (fn, t: t) => Belt.MutableMap.map(t, fn)
+  let partition = (fn, t: t) => {
+    let (match, noMatch) = Belt.Array.partition(toArray(t), fn)
+    (fromArray(match), fromArray(noMatch))
+  }
 }
 
 module Int = {
@@ -518,18 +522,17 @@ module A = {
       let makeIncrementalDown = (a, b) =>
         Array.make(a - b + 1, a) |> Array.mapi((i, c) => c - i) |> Belt.Array.map(_, float_of_int)
 
-      let split = (sortedArray: array<float>) => {
-        let continuous = []
+      let splitContinuousAndDiscreteForDuplicates = (sortedArray: array<float>) => {
+        let continuous: array<float> = []
         let discrete = FloatFloatMap.empty()
-        Belt.Array.forEachWithIndex(sortedArray, (_, element) => {
-          // let maxIndex = (sortedArray |> Array.length) - 1
-          // let possiblySimilarElements = switch index {
-          // | 0 => [index + 1]
-          // | n if n == maxIndex => [index - 1]
-          // | _ => [index - 1, index + 1]
-          // } |> Belt.Array.map(_, r => sortedArray[r])
-          // let hasSimilarElement = Belt.Array.some(possiblySimilarElements, r => r == element)
-          let hasSimilarElement = false
+        Belt.Array.forEachWithIndex(sortedArray, (index, element) => {
+          let maxIndex = (sortedArray |> Array.length) - 1
+          let possiblySimilarElements = switch index {
+          | 0 => maxIndex > 0 ? [index + 1] : [] // a lone sample has no neighbour to read
+          | n if n == maxIndex => [index - 1]
+          | _ => [index - 1, index + 1]
+          } |> Belt.Array.map(_, r => sortedArray[r])
+          let hasSimilarElement = Belt.Array.some(possiblySimilarElements, r => r == element)
           hasSimilarElement
             ? FloatFloatMap.increment(element, discrete)
             : {
@@ -541,6 +544,30 @@ module A = {
 
         (continuous, discrete)
       }
+
+      // Splits sorted samples into a continuous array and a discrete value => count map.
+      // A value stays discrete only when it occurs at least `minDiscreteWeight` times;
+      // rarer duplicates are expanded back into individual samples and merged into the
+      // continuous part, which is then re-sorted with `floatCompare`.
+      let splitContinuousAndDiscreteForMinWeight = (
+        sortedArray: array<float>,
+        minDiscreteWeight: int,
+      ) => {
+        let (continuous, discrete) = splitContinuousAndDiscreteForDuplicates(sortedArray)
+        let keepFn = v => Belt.Float.toInt(v) >= minDiscreteWeight
+        let (discreteToKeep, discreteToIntegrate) = FloatFloatMap.partition(
+          ((_, v)) => keepFn(v),
+          discrete,
+        )
+        let newContinuousSamples =
+          discreteToIntegrate->FloatFloatMap.toArray
+          |> fmap(((k, v)) => Belt.Array.makeBy(Belt.Float.toInt(v), _ => k))
+          |> Belt.Array.concatMany
+
+        let newContinuous = concat(continuous, newContinuousSamples)
+        newContinuous |> Array.fast_sort(floatCompare)
+        (newContinuous, discreteToKeep)
+      }
     }
   }
 