Gave SampleSetDist a private type

This commit is contained in:
Ozzie Gooen 2022-04-09 18:10:06 -04:00
parent 9430653b7a
commit 61aaca3e2f
9 changed files with 83 additions and 55 deletions

View File

@ -91,8 +91,9 @@ describe("toPointSet", () => {
})
test("on sample set distribution with under 4 points", () => {
let sampleSet = SampleSetDist.make([0.0, 1.0, 2.0, 3.0]) -> E.R.toExn;
let result =
run(FromDist(ToDist(ToPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
run(FromDist(ToDist(ToPointSet), SampleSet(sampleSet)))->outputMap(
FromDist(ToFloat(#Mean)),
)
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))

View File

@ -4,12 +4,12 @@ open TestHelpers
describe("Continuous and discrete splits", () => {
makeTest(
"splits (1)",
SampleSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]),
SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]),
([1.432, 1.33455, 2.0], E.FloatFloatMap.empty()),
)
makeTest(
"splits (2)",
SampleSet.Internals.T.splitContinuousAndDiscrete([
SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([
1.432,
1.33455,
2.0,
@ -26,13 +26,13 @@ describe("Continuous and discrete splits", () => {
E.A.concatMany([sorted, sorted, sorted, sorted]) |> Belt.SortArray.stableSortBy(_, compare)
}
let (_, discrete1) = SampleSet.Internals.T.splitContinuousAndDiscrete(
let (_, discrete1) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete(
makeDuplicatedArray(10),
)
let toArr1 = discrete1 |> E.FloatFloatMap.toArray
makeTest("splitMedium at count=10", toArr1 |> Belt.Array.length, 10)
let (_c, discrete2) = SampleSet.Internals.T.splitContinuousAndDiscrete(
let (_c, discrete2) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete(
makeDuplicatedArray(500),
)
let toArr2 = discrete2 |> E.FloatFloatMap.toArray

View File

@ -128,7 +128,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult
| ToDist(ToSampleSet(n)) =>
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
dist->GenericDist.toSampleSetDist(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
| ToDist(ToPointSet) =>
dist
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())

View File

@ -2,16 +2,18 @@
type t = GenericDist_Types.genericDist
type error = GenericDist_Types.error
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
type toSampleSetFn = t => result<array<float>, error>
type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error>
let mapStringErrors = n => n->E.R2.errMap(r => Error(GenericDist_Types.Other(r)))
let sampleN = (t: t, n) =>
switch t {
| PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
| Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
| SampleSet(r) => Ok(SampleSet.sampleN(r, n))
| SampleSet(r) => Ok(SampleSetDist.sampleN(r, n))
}
let toSampleSetDist = (t: t, n) => sampleN(t, n)->E.R.bind(SampleSetDist.make)->mapStringErrors
let mapStringErrors = n => n->E.R2.errMap(r => Error(GenericDist_Types.Other(r)))
let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
@ -63,7 +65,7 @@ let toPointSet = (
| PointSet(pointSet) => Ok(pointSet)
| Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(~xSelection, xyPointLength, r))
| SampleSet(r) => {
let response = SampleSet.toPointSetDist(
let response = SampleSetDist.toPointSetDist(
~samples=r,
~samplingInputs={
sampleCount: sampleCount,
@ -167,8 +169,9 @@ module AlgebraicCombination = {
t2: t,
) => {
let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation)
E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R2.fmap(((a, b)) => {
Belt.Array.zip(a, b)->E.A2.fmap(((a, b)) => arithmeticOperation(a, b))
E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R.bind(((a, b)) => {
SampleSetDist.runMonteCarlo(arithmeticOperation, a, b)
->mapStringErrors
})
}
@ -200,13 +203,15 @@ module AlgebraicCombination = {
| Some(Error(e)) => Error(Other(e))
| None =>
switch chooseConvolutionOrMonteCarlo(t1, t2) {
| #CalculateWithMonteCarlo =>
runMonteCarlo(
toSampleSetFn,
arithmeticOperation,
t1,
t2,
)->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
| #CalculateWithMonteCarlo => {
let sampleSetDist: result<SampleSetDist.t, error> = runMonteCarlo(
toSampleSetFn,
arithmeticOperation,
t1,
t2,
)
sampleSetDist->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
}
| #CalculateWithConvolution =>
runConvolution(
toPointSetFn,

View File

@ -1,11 +1,12 @@
type t = GenericDist_Types.genericDist
type error = GenericDist_Types.error
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
type toSampleSetFn = t => result<array<float>, error>
type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error>
let sampleN: (t, int) => result<array<float>, error>
let toSampleSetDist: (t, int) => Belt.Result.t<QuriSquiggleLang.SampleSetDist.t, error>
let fromFloat: float => t

View File

@ -1,6 +1,6 @@
type genericDist =
| PointSet(PointSetTypes.pointSetDist)
| SampleSet(SampleSet.t)
| SampleSet(SampleSetDist.t)
| Symbolic(SymbolicDistTypes.symbolicDist)
@genType

View File

@ -0,0 +1,48 @@
module T: {
type t
let make: array<float> => result<t, string>
let get: t => array<float>
} = {
type t = array<float>
let make = (a: array<float>) =>
if E.A.length(a) > 5 {
Ok(a)
} else {
Error("too small")
}
let get = (a: t) => a
}
include T
// TODO: Refactor to raise correct error when not enough samples
let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()) =>
SampleSetDist_ToPointSet.toPointSetDist(~samples=get(samples), ~samplingInputs, ())
//Randomly get one sample from the distribution
let sample = (t: t): float => {
let i = E.Int.random(~min=0, ~max=E.A.length(get(t)) - 1)
E.A.unsafe_get(get(t), i)
}
/*
If asked for a length of samples shorter or equal the length of the distribution,
return this first n samples of this distribution.
Else, return n random samples of the distribution.
The former helps in cases where multiple distributions are correlated.
However, if n > length(t), then there's no clear right answer, so we just randomly
sample everything.
*/
let sampleN = (t: t, n) => {
if n <= E.A.length(get(t)) {
E.A.slice(get(t), ~offset=0, ~len=n)
} else {
Belt.Array.makeBy(n, _ => sample(t))
}
}
let runMonteCarlo = (fn: (float, float) => float, t1: t, t2: t) => {
let samples = Belt.Array.zip(get(t1), get(t2))->E.A2.fmap(((a, b)) => fn(a, b))
make(samples)
}

View File

@ -1,8 +1,3 @@
@genType
type t = array<float>
// TODO: Refactor to raise correct error when not enough samples
module Internals = {
module Types = {
type samplingStats = {
@ -145,25 +140,3 @@ let toPointSetDist = (
samplesParse
}
//Randomly get one sample from the distribution
let sample = (t: t): float => {
let i = E.Int.random(~min=0, ~max=E.A.length(t) - 1)
E.A.unsafe_get(t, i)
}
/*
If asked for a length of samples shorter or equal the length of the distribution,
return this first n samples of this distribution.
Else, return n random samples of the distribution.
The former helps in cases where multiple distributions are correlated.
However, if n > length(t), then there's no clear right answer, so we just randomly
sample everything.
*/
let sampleN = (t: t, n) => {
if n <= E.A.length(t) {
E.A.slice(t, ~offset=0, ~len=n)
} else {
Belt.Array.makeBy(n, _ => sample(t))
}
}

View File

@ -218,15 +218,15 @@ module SamplingDistribution = {
algebraicOp,
a,
b,
)
) |> E.O.toResult("Could not get samples")
let sampleSetDist = samples -> E.R.bind(SampleSetDist.make)
let pointSetDist =
samples
|> E.O.fmap(r =>
SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())
)
|> E.O.bind(_, r => r.pointSetDist)
|> E.O.toResult("No response")
sampleSetDist
-> E.R2.fmap(r =>
SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()))
-> E.R.bind(r => r.pointSetDist |> E.O.toResult("combineShapesUsingSampling Error"))
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
})
}