Gave SampleSetDist a private type
This commit is contained in:
parent
9430653b7a
commit
61aaca3e2f
|
@ -91,8 +91,9 @@ describe("toPointSet", () => {
|
|||
})
|
||||
|
||||
test("on sample set distribution with under 4 points", () => {
|
||||
let sampleSet = SampleSetDist.make([0.0, 1.0, 2.0, 3.0]) -> E.R.toExn;
|
||||
let result =
|
||||
run(FromDist(ToDist(ToPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
|
||||
run(FromDist(ToDist(ToPointSet), SampleSet(sampleSet)))->outputMap(
|
||||
FromDist(ToFloat(#Mean)),
|
||||
)
|
||||
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))
|
||||
|
|
|
@ -4,12 +4,12 @@ open TestHelpers
|
|||
describe("Continuous and discrete splits", () => {
|
||||
makeTest(
|
||||
"splits (1)",
|
||||
SampleSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]),
|
||||
SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]),
|
||||
([1.432, 1.33455, 2.0], E.FloatFloatMap.empty()),
|
||||
)
|
||||
makeTest(
|
||||
"splits (2)",
|
||||
SampleSet.Internals.T.splitContinuousAndDiscrete([
|
||||
SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([
|
||||
1.432,
|
||||
1.33455,
|
||||
2.0,
|
||||
|
@ -26,13 +26,13 @@ describe("Continuous and discrete splits", () => {
|
|||
E.A.concatMany([sorted, sorted, sorted, sorted]) |> Belt.SortArray.stableSortBy(_, compare)
|
||||
}
|
||||
|
||||
let (_, discrete1) = SampleSet.Internals.T.splitContinuousAndDiscrete(
|
||||
let (_, discrete1) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete(
|
||||
makeDuplicatedArray(10),
|
||||
)
|
||||
let toArr1 = discrete1 |> E.FloatFloatMap.toArray
|
||||
makeTest("splitMedium at count=10", toArr1 |> Belt.Array.length, 10)
|
||||
|
||||
let (_c, discrete2) = SampleSet.Internals.T.splitContinuousAndDiscrete(
|
||||
let (_c, discrete2) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete(
|
||||
makeDuplicatedArray(500),
|
||||
)
|
||||
let toArr2 = discrete2 |> E.FloatFloatMap.toArray
|
||||
|
|
|
@ -128,7 +128,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
|||
->E.R2.fmap(r => Dist(r))
|
||||
->OutputLocal.fromResult
|
||||
| ToDist(ToSampleSet(n)) =>
|
||||
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
|
||||
dist->GenericDist.toSampleSetDist(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
|
||||
| ToDist(ToPointSet) =>
|
||||
dist
|
||||
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())
|
||||
|
|
|
@ -2,16 +2,18 @@
|
|||
type t = GenericDist_Types.genericDist
|
||||
type error = GenericDist_Types.error
|
||||
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
|
||||
type toSampleSetFn = t => result<array<float>, error>
|
||||
type toSampleSetFn = t => result<SampleSetDist.t, error>
|
||||
type scaleMultiplyFn = (t, float) => result<t, error>
|
||||
type pointwiseAddFn = (t, t) => result<t, error>
|
||||
|
||||
let mapStringErrors = n => n->E.R2.errMap(r => Error(GenericDist_Types.Other(r)))
|
||||
let sampleN = (t: t, n) =>
|
||||
switch t {
|
||||
| PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
||||
| Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
|
||||
| SampleSet(r) => Ok(SampleSet.sampleN(r, n))
|
||||
| SampleSet(r) => Ok(SampleSetDist.sampleN(r, n))
|
||||
}
|
||||
let toSampleSetDist = (t: t, n) => sampleN(t, n)->E.R.bind(SampleSetDist.make)->mapStringErrors
|
||||
let mapStringErrors = n => n->E.R2.errMap(r => Error(GenericDist_Types.Other(r)))
|
||||
|
||||
let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
|
||||
|
||||
|
@ -63,7 +65,7 @@ let toPointSet = (
|
|||
| PointSet(pointSet) => Ok(pointSet)
|
||||
| Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(~xSelection, xyPointLength, r))
|
||||
| SampleSet(r) => {
|
||||
let response = SampleSet.toPointSetDist(
|
||||
let response = SampleSetDist.toPointSetDist(
|
||||
~samples=r,
|
||||
~samplingInputs={
|
||||
sampleCount: sampleCount,
|
||||
|
@ -167,8 +169,9 @@ module AlgebraicCombination = {
|
|||
t2: t,
|
||||
) => {
|
||||
let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation)
|
||||
E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R2.fmap(((a, b)) => {
|
||||
Belt.Array.zip(a, b)->E.A2.fmap(((a, b)) => arithmeticOperation(a, b))
|
||||
E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R.bind(((a, b)) => {
|
||||
SampleSetDist.runMonteCarlo(arithmeticOperation, a, b)
|
||||
->mapStringErrors
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -200,13 +203,15 @@ module AlgebraicCombination = {
|
|||
| Some(Error(e)) => Error(Other(e))
|
||||
| None =>
|
||||
switch chooseConvolutionOrMonteCarlo(t1, t2) {
|
||||
| #CalculateWithMonteCarlo =>
|
||||
runMonteCarlo(
|
||||
| #CalculateWithMonteCarlo => {
|
||||
let sampleSetDist: result<SampleSetDist.t, error> = runMonteCarlo(
|
||||
toSampleSetFn,
|
||||
arithmeticOperation,
|
||||
t1,
|
||||
t2,
|
||||
)->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
|
||||
)
|
||||
sampleSetDist->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
|
||||
}
|
||||
| #CalculateWithConvolution =>
|
||||
runConvolution(
|
||||
toPointSetFn,
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
type t = GenericDist_Types.genericDist
|
||||
type error = GenericDist_Types.error
|
||||
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
|
||||
type toSampleSetFn = t => result<array<float>, error>
|
||||
type toSampleSetFn = t => result<SampleSetDist.t, error>
|
||||
type scaleMultiplyFn = (t, float) => result<t, error>
|
||||
type pointwiseAddFn = (t, t) => result<t, error>
|
||||
|
||||
let sampleN: (t, int) => result<array<float>, error>
|
||||
let toSampleSetDist: (t, int) => Belt.Result.t<QuriSquiggleLang.SampleSetDist.t, error>
|
||||
|
||||
let fromFloat: float => t
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
type genericDist =
|
||||
| PointSet(PointSetTypes.pointSetDist)
|
||||
| SampleSet(SampleSet.t)
|
||||
| SampleSet(SampleSetDist.t)
|
||||
| Symbolic(SymbolicDistTypes.symbolicDist)
|
||||
|
||||
@genType
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
module T: {
|
||||
type t
|
||||
let make: array<float> => result<t, string>
|
||||
let get: t => array<float>
|
||||
} = {
|
||||
type t = array<float>
|
||||
let make = (a: array<float>) =>
|
||||
if E.A.length(a) > 5 {
|
||||
Ok(a)
|
||||
} else {
|
||||
Error("too small")
|
||||
}
|
||||
let get = (a: t) => a
|
||||
}
|
||||
|
||||
include T
|
||||
|
||||
// TODO: Refactor to raise correct error when not enough samples
|
||||
|
||||
let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()) =>
|
||||
SampleSetDist_ToPointSet.toPointSetDist(~samples=get(samples), ~samplingInputs, ())
|
||||
|
||||
//Randomly get one sample from the distribution
|
||||
let sample = (t: t): float => {
|
||||
let i = E.Int.random(~min=0, ~max=E.A.length(get(t)) - 1)
|
||||
E.A.unsafe_get(get(t), i)
|
||||
}
|
||||
|
||||
/*
|
||||
If asked for a length of samples shorter or equal the length of the distribution,
|
||||
return this first n samples of this distribution.
|
||||
Else, return n random samples of the distribution.
|
||||
The former helps in cases where multiple distributions are correlated.
|
||||
However, if n > length(t), then there's no clear right answer, so we just randomly
|
||||
sample everything.
|
||||
*/
|
||||
let sampleN = (t: t, n) => {
|
||||
if n <= E.A.length(get(t)) {
|
||||
E.A.slice(get(t), ~offset=0, ~len=n)
|
||||
} else {
|
||||
Belt.Array.makeBy(n, _ => sample(t))
|
||||
}
|
||||
}
|
||||
|
||||
let runMonteCarlo = (fn: (float, float) => float, t1: t, t2: t) => {
|
||||
let samples = Belt.Array.zip(get(t1), get(t2))->E.A2.fmap(((a, b)) => fn(a, b))
|
||||
make(samples)
|
||||
}
|
|
@ -1,8 +1,3 @@
|
|||
@genType
|
||||
type t = array<float>
|
||||
|
||||
// TODO: Refactor to raise correct error when not enough samples
|
||||
|
||||
module Internals = {
|
||||
module Types = {
|
||||
type samplingStats = {
|
||||
|
@ -145,25 +140,3 @@ let toPointSetDist = (
|
|||
|
||||
samplesParse
|
||||
}
|
||||
|
||||
//Randomly get one sample from the distribution
|
||||
let sample = (t: t): float => {
|
||||
let i = E.Int.random(~min=0, ~max=E.A.length(t) - 1)
|
||||
E.A.unsafe_get(t, i)
|
||||
}
|
||||
|
||||
/*
|
||||
If asked for a length of samples shorter or equal the length of the distribution,
|
||||
return this first n samples of this distribution.
|
||||
Else, return n random samples of the distribution.
|
||||
The former helps in cases where multiple distributions are correlated.
|
||||
However, if n > length(t), then there's no clear right answer, so we just randomly
|
||||
sample everything.
|
||||
*/
|
||||
let sampleN = (t: t, n) => {
|
||||
if n <= E.A.length(t) {
|
||||
E.A.slice(t, ~offset=0, ~len=n)
|
||||
} else {
|
||||
Belt.Array.makeBy(n, _ => sample(t))
|
||||
}
|
||||
}
|
|
@ -218,15 +218,15 @@ module SamplingDistribution = {
|
|||
algebraicOp,
|
||||
a,
|
||||
b,
|
||||
)
|
||||
) |> E.O.toResult("Could not get samples")
|
||||
|
||||
let sampleSetDist = samples -> E.R.bind(SampleSetDist.make)
|
||||
|
||||
let pointSetDist =
|
||||
samples
|
||||
|> E.O.fmap(r =>
|
||||
SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())
|
||||
)
|
||||
|> E.O.bind(_, r => r.pointSetDist)
|
||||
|> E.O.toResult("No response")
|
||||
sampleSetDist
|
||||
-> E.R2.fmap(r =>
|
||||
SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()))
|
||||
-> E.R.bind(r => r.pointSetDist |> E.O.toResult("combineShapesUsingSampling Error"))
|
||||
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user