Gave SampleSetDist a private type

This commit is contained in:
Ozzie Gooen 2022-04-09 18:10:06 -04:00
parent 9430653b7a
commit 61aaca3e2f
9 changed files with 83 additions and 55 deletions

View File

@ -91,8 +91,9 @@ describe("toPointSet", () => {
}) })
test("on sample set distribution with under 4 points", () => { test("on sample set distribution with under 4 points", () => {
let sampleSet = SampleSetDist.make([0.0, 1.0, 2.0, 3.0]) -> E.R.toExn;
let result = let result =
run(FromDist(ToDist(ToPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap( run(FromDist(ToDist(ToPointSet), SampleSet(sampleSet)))->outputMap(
FromDist(ToFloat(#Mean)), FromDist(ToFloat(#Mean)),
) )
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed"))) expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))

View File

@ -4,12 +4,12 @@ open TestHelpers
describe("Continuous and discrete splits", () => { describe("Continuous and discrete splits", () => {
makeTest( makeTest(
"splits (1)", "splits (1)",
SampleSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]), SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]),
([1.432, 1.33455, 2.0], E.FloatFloatMap.empty()), ([1.432, 1.33455, 2.0], E.FloatFloatMap.empty()),
) )
makeTest( makeTest(
"splits (2)", "splits (2)",
SampleSet.Internals.T.splitContinuousAndDiscrete([ SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([
1.432, 1.432,
1.33455, 1.33455,
2.0, 2.0,
@ -26,13 +26,13 @@ describe("Continuous and discrete splits", () => {
E.A.concatMany([sorted, sorted, sorted, sorted]) |> Belt.SortArray.stableSortBy(_, compare) E.A.concatMany([sorted, sorted, sorted, sorted]) |> Belt.SortArray.stableSortBy(_, compare)
} }
let (_, discrete1) = SampleSet.Internals.T.splitContinuousAndDiscrete( let (_, discrete1) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete(
makeDuplicatedArray(10), makeDuplicatedArray(10),
) )
let toArr1 = discrete1 |> E.FloatFloatMap.toArray let toArr1 = discrete1 |> E.FloatFloatMap.toArray
makeTest("splitMedium at count=10", toArr1 |> Belt.Array.length, 10) makeTest("splitMedium at count=10", toArr1 |> Belt.Array.length, 10)
let (_c, discrete2) = SampleSet.Internals.T.splitContinuousAndDiscrete( let (_c, discrete2) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete(
makeDuplicatedArray(500), makeDuplicatedArray(500),
) )
let toArr2 = discrete2 |> E.FloatFloatMap.toArray let toArr2 = discrete2 |> E.FloatFloatMap.toArray

View File

@ -128,7 +128,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
->E.R2.fmap(r => Dist(r)) ->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult ->OutputLocal.fromResult
| ToDist(ToSampleSet(n)) => | ToDist(ToSampleSet(n)) =>
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult dist->GenericDist.toSampleSetDist(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
| ToDist(ToPointSet) => | ToDist(ToPointSet) =>
dist dist
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ()) ->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())

View File

@ -2,16 +2,18 @@
type t = GenericDist_Types.genericDist type t = GenericDist_Types.genericDist
type error = GenericDist_Types.error type error = GenericDist_Types.error
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error> type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
type toSampleSetFn = t => result<array<float>, error> type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error> type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error> type pointwiseAddFn = (t, t) => result<t, error>
let mapStringErrors = n => n->E.R2.errMap(r => Error(GenericDist_Types.Other(r)))
let sampleN = (t: t, n) => let sampleN = (t: t, n) =>
switch t { switch t {
| PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) | PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
| Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) | Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
| SampleSet(r) => Ok(SampleSet.sampleN(r, n)) | SampleSet(r) => Ok(SampleSetDist.sampleN(r, n))
} }
let toSampleSetDist = (t: t, n) => sampleN(t, n)->E.R.bind(SampleSetDist.make)->mapStringErrors
let mapStringErrors = n => n->E.R2.errMap(r => Error(GenericDist_Types.Other(r)))
let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f)) let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
@ -63,7 +65,7 @@ let toPointSet = (
| PointSet(pointSet) => Ok(pointSet) | PointSet(pointSet) => Ok(pointSet)
| Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(~xSelection, xyPointLength, r)) | Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(~xSelection, xyPointLength, r))
| SampleSet(r) => { | SampleSet(r) => {
let response = SampleSet.toPointSetDist( let response = SampleSetDist.toPointSetDist(
~samples=r, ~samples=r,
~samplingInputs={ ~samplingInputs={
sampleCount: sampleCount, sampleCount: sampleCount,
@ -167,8 +169,9 @@ module AlgebraicCombination = {
t2: t, t2: t,
) => { ) => {
let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation) let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation)
E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R2.fmap(((a, b)) => { E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R.bind(((a, b)) => {
Belt.Array.zip(a, b)->E.A2.fmap(((a, b)) => arithmeticOperation(a, b)) SampleSetDist.runMonteCarlo(arithmeticOperation, a, b)
->mapStringErrors
}) })
} }
@ -200,13 +203,15 @@ module AlgebraicCombination = {
| Some(Error(e)) => Error(Other(e)) | Some(Error(e)) => Error(Other(e))
| None => | None =>
switch chooseConvolutionOrMonteCarlo(t1, t2) { switch chooseConvolutionOrMonteCarlo(t1, t2) {
| #CalculateWithMonteCarlo => | #CalculateWithMonteCarlo => {
runMonteCarlo( let sampleSetDist: result<SampleSetDist.t, error> = runMonteCarlo(
toSampleSetFn, toSampleSetFn,
arithmeticOperation, arithmeticOperation,
t1, t1,
t2, t2,
)->E.R2.fmap(r => GenericDist_Types.SampleSet(r)) )
sampleSetDist->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
}
| #CalculateWithConvolution => | #CalculateWithConvolution =>
runConvolution( runConvolution(
toPointSetFn, toPointSetFn,

View File

@ -1,11 +1,12 @@
type t = GenericDist_Types.genericDist type t = GenericDist_Types.genericDist
type error = GenericDist_Types.error type error = GenericDist_Types.error
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error> type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
type toSampleSetFn = t => result<array<float>, error> type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error> type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error> type pointwiseAddFn = (t, t) => result<t, error>
let sampleN: (t, int) => result<array<float>, error> let sampleN: (t, int) => result<array<float>, error>
let toSampleSetDist: (t, int) => Belt.Result.t<QuriSquiggleLang.SampleSetDist.t, error>
let fromFloat: float => t let fromFloat: float => t

View File

@ -1,6 +1,6 @@
type genericDist = type genericDist =
| PointSet(PointSetTypes.pointSetDist) | PointSet(PointSetTypes.pointSetDist)
| SampleSet(SampleSet.t) | SampleSet(SampleSetDist.t)
| Symbolic(SymbolicDistTypes.symbolicDist) | Symbolic(SymbolicDistTypes.symbolicDist)
@genType @genType

View File

@ -0,0 +1,48 @@
// Sample-set distribution whose representation is hidden behind the module signature:
// outside of T the type `t` is abstract, so values can only be built through `make`
// and read back through `get`.
module T: {
  type t
  // Wraps the array when E.A.length reports more than 5 elements; otherwise Error("too small").
  let make: array<float> => result<t, string>
  // Returns the underlying array of samples.
  let get: t => array<float>
} = {
  type t = array<float>
  let make = (samples: array<float>) =>
    E.A.length(samples) > 5 ? Ok(samples) : Error("too small")
  let get = (wrapped: t) => wrapped
}
// Re-export `t`, `make` and `get` at the top level of this file.
include T
// TODO: Refactor to raise correct error when not enough samples
// Unwraps the sample set with `get` and delegates to
// SampleSetDist_ToPointSet.toPointSetDist, passing `samplingInputs` through unchanged.
let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()) => {
  let rawSamples = get(samples)
  SampleSetDist_ToPointSet.toPointSetDist(~samples=rawSamples, ~samplingInputs, ())
}
// Picks one sample from the distribution at a randomly drawn index.
// NOTE(review): the index is drawn with max = length - 1. Confirm that E.Int.random
// treats `max` as inclusive; if it is exclusive, the last sample can never be selected.
let sample = (t: t): float => {
  let values = get(t)
  let lastIndex = E.A.length(values) - 1
  let index = E.Int.random(~min=0, ~max=lastIndex)
  E.A.unsafe_get(values, index)
}
/*
 Returns n samples from the distribution.
 When n is less than or equal to the number of samples held by the distribution,
 the first n samples are returned as-is. Keeping the original order helps in cases
 where multiple distributions are correlated.
 When n is greater than the number of samples there is no clear right answer, so
 every one of the n returned values is drawn at random via `sample`.
 */
let sampleN = (t: t, n) => {
  let values = get(t)
  switch n <= E.A.length(values) {
  | true => E.A.slice(values, ~offset=0, ~len=n)
  | false => Belt.Array.makeBy(n, _ => sample(t))
  }
}
// Combines two sample sets pairwise: the underlying arrays are zipped, `fn` is applied
// to each (t1 sample, t2 sample) pair, and the resulting array is passed back through
// `make`, so the result is an Error when `make` rejects the combined samples.
let runMonteCarlo = (fn: (float, float) => float, t1: t, t2: t) => {
  let pairs = Belt.Array.zip(get(t1), get(t2))
  let combined = pairs->E.A2.fmap(((x, y)) => fn(x, y))
  make(combined)
}

View File

@ -1,8 +1,3 @@
@genType
type t = array<float>
// TODO: Refactor to raise correct error when not enough samples
module Internals = { module Internals = {
module Types = { module Types = {
type samplingStats = { type samplingStats = {
@ -145,25 +140,3 @@ let toPointSetDist = (
samplesParse samplesParse
} }
//Randomly get one sample from the distribution
let sample = (t: t): float => {
let i = E.Int.random(~min=0, ~max=E.A.length(t) - 1)
E.A.unsafe_get(t, i)
}
/*
If asked for a length of samples shorter or equal the length of the distribution,
return this first n samples of this distribution.
Else, return n random samples of the distribution.
The former helps in cases where multiple distributions are correlated.
However, if n > length(t), then there's no clear right answer, so we just randomly
sample everything.
*/
let sampleN = (t: t, n) => {
if n <= E.A.length(t) {
E.A.slice(t, ~offset=0, ~len=n)
} else {
Belt.Array.makeBy(n, _ => sample(t))
}
}

View File

@ -218,15 +218,15 @@ module SamplingDistribution = {
algebraicOp, algebraicOp,
a, a,
b, b,
) ) |> E.O.toResult("Could not get samples")
let sampleSetDist = samples -> E.R.bind(SampleSetDist.make)
let pointSetDist = let pointSetDist =
samples sampleSetDist
|> E.O.fmap(r => -> E.R2.fmap(r =>
SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()) SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()))
) -> E.R.bind(r => r.pointSetDist |> E.O.toResult("combineShapesUsingSampling Error"))
|> E.O.bind(_, r => r.pointSetDist)
|> E.O.toResult("No response")
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r))) pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
}) })
} }