Merge pull request #211 from QURIresearch/Refactor-april

SampleSetDist refactors
This commit is contained in:
Ozzie Gooen 2022-04-09 22:01:46 -04:00 committed by GitHub
commit 66c1081218
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 177 additions and 130 deletions

View File

@ -4,10 +4,10 @@ open Expect
describe("Bandwidth", () => { describe("Bandwidth", () => {
test("nrd0()", () => { test("nrd0()", () => {
let data = [1., 4., 3., 2.] let data = [1., 4., 3., 2.]
expect(Bandwidth.nrd0(data)) -> toEqual(0.7625801874014622) expect(SampleSetDist_Bandwidth.nrd0(data)) -> toEqual(0.7625801874014622)
}) })
test("nrd()", () => { test("nrd()", () => {
let data = [1., 4., 3., 2.] let data = [1., 4., 3., 2.]
expect(Bandwidth.nrd(data)) -> toEqual(0.8981499984950554) expect(SampleSetDist_Bandwidth.nrd(data)) -> toEqual(0.8981499984950554)
}) })
}) })

View File

@ -90,14 +90,6 @@ describe("toPointSet", () => {
expect(result)->toBeSoCloseTo(5.0, ~digits=0) expect(result)->toBeSoCloseTo(5.0, ~digits=0)
}) })
test("on sample set distribution with under 4 points", () => {
let result =
run(FromDist(ToDist(ToPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
FromDist(ToFloat(#Mean)),
)
expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))
})
test("on sample set", () => { test("on sample set", () => {
let result = let result =
run(FromDist(ToDist(ToPointSet), normalDist5)) run(FromDist(ToDist(ToPointSet), normalDist5))

View File

@ -4,12 +4,12 @@ open TestHelpers
describe("Continuous and discrete splits", () => { describe("Continuous and discrete splits", () => {
makeTest( makeTest(
"splits (1)", "splits (1)",
SampleSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]), SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([1.432, 1.33455, 2.0]),
([1.432, 1.33455, 2.0], E.FloatFloatMap.empty()), ([1.432, 1.33455, 2.0], E.FloatFloatMap.empty()),
) )
makeTest( makeTest(
"splits (2)", "splits (2)",
SampleSet.Internals.T.splitContinuousAndDiscrete([ SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete([
1.432, 1.432,
1.33455, 1.33455,
2.0, 2.0,
@ -26,13 +26,13 @@ describe("Continuous and discrete splits", () => {
E.A.concatMany([sorted, sorted, sorted, sorted]) |> Belt.SortArray.stableSortBy(_, compare) E.A.concatMany([sorted, sorted, sorted, sorted]) |> Belt.SortArray.stableSortBy(_, compare)
} }
let (_, discrete1) = SampleSet.Internals.T.splitContinuousAndDiscrete( let (_, discrete1) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete(
makeDuplicatedArray(10), makeDuplicatedArray(10),
) )
let toArr1 = discrete1 |> E.FloatFloatMap.toArray let toArr1 = discrete1 |> E.FloatFloatMap.toArray
makeTest("splitMedium at count=10", toArr1 |> Belt.Array.length, 10) makeTest("splitMedium at count=10", toArr1 |> Belt.Array.length, 10)
let (_c, discrete2) = SampleSet.Internals.T.splitContinuousAndDiscrete( let (_c, discrete2) = SampleSetDist_ToPointSet.Internals.T.splitContinuousAndDiscrete(
makeDuplicatedArray(500), makeDuplicatedArray(500),
) )
let toArr2 = discrete2 |> E.FloatFloatMap.toArray let toArr2 = discrete2 |> E.FloatFloatMap.toArray

View File

@ -1,4 +1,4 @@
import { run, GenericDist, resultMap } from "../src/js/index"; import { run, GenericDist, resultMap, makeSampleSetDist } from "../src/js/index";
let testRun = (x: string) => { let testRun = (x: string) => {
let result = run(x); let result = run(x);
@ -41,6 +41,7 @@ describe("Multimodal too many weights error", () => {
describe("GenericDist", () => { describe("GenericDist", () => {
//It's important that sampleCount is less than 9. If it's more, than that will create randomness //It's important that sampleCount is less than 9. If it's more, than that will create randomness
//Also, note, the value should be created using makeSampleSetDist() later on.
let env = { sampleCount: 8, xyPointLength: 100 }; let env = { sampleCount: 8, xyPointLength: 100 };
let dist = new GenericDist( let dist = new GenericDist(
{ tag: "SampleSet", value: [3, 4, 5, 6, 6, 7, 10, 15, 30] }, { tag: "SampleSet", value: [3, 4, 5, 6, 6, 7, 10, 15, 30] },

View File

@ -14,6 +14,7 @@ import {
resultFloat, resultFloat,
resultString, resultString,
} from "../rescript/TypescriptInterface.gen"; } from "../rescript/TypescriptInterface.gen";
export {makeSampleSetDist} from "../rescript/TypescriptInterface.gen";
import { import {
Constructors_mean, Constructors_mean,
Constructors_sample, Constructors_sample,
@ -32,13 +33,13 @@ import {
Constructors_algebraicDivide, Constructors_algebraicDivide,
Constructors_algebraicSubtract, Constructors_algebraicSubtract,
Constructors_algebraicLogarithm, Constructors_algebraicLogarithm,
Constructors_algebraicExponentiate, Constructors_algebraicPower,
Constructors_pointwiseAdd, Constructors_pointwiseAdd,
Constructors_pointwiseMultiply, Constructors_pointwiseMultiply,
Constructors_pointwiseDivide, Constructors_pointwiseDivide,
Constructors_pointwiseSubtract, Constructors_pointwiseSubtract,
Constructors_pointwiseLogarithm, Constructors_pointwiseLogarithm,
Constructors_pointwiseExponentiate, Constructors_pointwisePower,
} from "../rescript/Distributions/DistributionOperation/DistributionOperation.gen"; } from "../rescript/Distributions/DistributionOperation/DistributionOperation.gen";
export let defaultSamplingInputs: SamplingInputs = { export let defaultSamplingInputs: SamplingInputs = {
@ -79,6 +80,10 @@ export function resultMap(r: result, mapFn: any): result {
} }
} }
/**
 * Unwraps a `result` and hands back its `value` payload.
 *
 * No tag check is performed: whatever is stored in `r.value` is returned
 * as-is, whether the result represents a success or a failure.
 *
 * @param r - The result to unwrap.
 * @returns The `value` field of `r`.
 */
export function resultExn(r: result): any {
  // The previous body was the bare expression `r.value`, which evaluated the
  // field and then fell off the end of the function, yielding `undefined`.
  return r.value;
}
export class GenericDist { export class GenericDist {
t: genericDist; t: genericDist;
env: env; env: env;
@ -179,9 +184,9 @@ export class GenericDist {
); );
} }
algebraicExponentiate(d2: GenericDist) { algebraicPower(d2: GenericDist) {
return this.mapResultDist( return this.mapResultDist(
Constructors_algebraicExponentiate({ env: this.env }, this.t, d2.t) Constructors_algebraicPower({ env: this.env }, this.t, d2.t)
); );
} }
@ -215,9 +220,9 @@ export class GenericDist {
); );
} }
pointwiseExponentiate(d2: GenericDist) { pointwisePower(d2: GenericDist) {
return this.mapResultDist( return this.mapResultDist(
Constructors_pointwiseExponentiate({ env: this.env }, this.t, d2.t) Constructors_pointwisePower({ env: this.env }, this.t, d2.t)
); );
} }
} }

View File

@ -128,7 +128,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
->E.R2.fmap(r => Dist(r)) ->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult ->OutputLocal.fromResult
| ToDist(ToSampleSet(n)) => | ToDist(ToSampleSet(n)) =>
dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult dist
->GenericDist.toSampleSetDist(n)
->E.R2.fmap(r => Dist(SampleSet(r)))
->OutputLocal.fromResult
| ToDist(ToPointSet) => | ToDist(ToPointSet) =>
dist dist
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ()) ->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())
@ -204,7 +207,8 @@ module Constructors = {
C.truncate(dist, leftCutoff, rightCutoff)->run(~env)->toDistR C.truncate(dist, leftCutoff, rightCutoff)->run(~env)->toDistR
let inspect = (~env, dist) => C.inspect(dist)->run(~env)->toDistR let inspect = (~env, dist) => C.inspect(dist)->run(~env)->toDistR
let toString = (~env, dist) => C.toString(dist)->run(~env)->toStringR let toString = (~env, dist) => C.toString(dist)->run(~env)->toStringR
let toSparkline = (~env, dist, bucketCount) => C.toSparkline(dist, bucketCount)->run(~env)->toStringR let toSparkline = (~env, dist, bucketCount) =>
C.toSparkline(dist, bucketCount)->run(~env)->toStringR
let algebraicAdd = (~env, dist1, dist2) => C.algebraicAdd(dist1, dist2)->run(~env)->toDistR let algebraicAdd = (~env, dist1, dist2) => C.algebraicAdd(dist1, dist2)->run(~env)->toDistR
let algebraicMultiply = (~env, dist1, dist2) => let algebraicMultiply = (~env, dist1, dist2) =>
C.algebraicMultiply(dist1, dist2)->run(~env)->toDistR C.algebraicMultiply(dist1, dist2)->run(~env)->toDistR
@ -213,8 +217,7 @@ module Constructors = {
C.algebraicSubtract(dist1, dist2)->run(~env)->toDistR C.algebraicSubtract(dist1, dist2)->run(~env)->toDistR
let algebraicLogarithm = (~env, dist1, dist2) => let algebraicLogarithm = (~env, dist1, dist2) =>
C.algebraicLogarithm(dist1, dist2)->run(~env)->toDistR C.algebraicLogarithm(dist1, dist2)->run(~env)->toDistR
let algebraicExponentiate = (~env, dist1, dist2) => let algebraicPower = (~env, dist1, dist2) => C.algebraicPower(dist1, dist2)->run(~env)->toDistR
C.algebraicExponentiate(dist1, dist2)->run(~env)->toDistR
let pointwiseAdd = (~env, dist1, dist2) => C.pointwiseAdd(dist1, dist2)->run(~env)->toDistR let pointwiseAdd = (~env, dist1, dist2) => C.pointwiseAdd(dist1, dist2)->run(~env)->toDistR
let pointwiseMultiply = (~env, dist1, dist2) => let pointwiseMultiply = (~env, dist1, dist2) =>
C.pointwiseMultiply(dist1, dist2)->run(~env)->toDistR C.pointwiseMultiply(dist1, dist2)->run(~env)->toDistR
@ -223,6 +226,5 @@ module Constructors = {
C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR
let pointwiseLogarithm = (~env, dist1, dist2) => let pointwiseLogarithm = (~env, dist1, dist2) =>
C.pointwiseLogarithm(dist1, dist2)->run(~env)->toDistR C.pointwiseLogarithm(dist1, dist2)->run(~env)->toDistR
let pointwiseExponentiate = (~env, dist1, dist2) => let pointwisePower = (~env, dist1, dist2) => C.pointwisePower(dist1, dist2)->run(~env)->toDistR
C.pointwiseExponentiate(dist1, dist2)->run(~env)->toDistR
} }

View File

@ -79,7 +79,7 @@ module Constructors: {
@genType @genType
let algebraicLogarithm: (~env: env, genericDist, genericDist) => result<genericDist, error> let algebraicLogarithm: (~env: env, genericDist, genericDist) => result<genericDist, error>
@genType @genType
let algebraicExponentiate: (~env: env, genericDist, genericDist) => result<genericDist, error> let algebraicPower: (~env: env, genericDist, genericDist) => result<genericDist, error>
@genType @genType
let pointwiseAdd: (~env: env, genericDist, genericDist) => result<genericDist, error> let pointwiseAdd: (~env: env, genericDist, genericDist) => result<genericDist, error>
@genType @genType
@ -91,5 +91,5 @@ module Constructors: {
@genType @genType
let pointwiseLogarithm: (~env: env, genericDist, genericDist) => result<genericDist, error> let pointwiseLogarithm: (~env: env, genericDist, genericDist) => result<genericDist, error>
@genType @genType
let pointwiseExponentiate: (~env: env, genericDist, genericDist) => result<genericDist, error> let pointwisePower: (~env: env, genericDist, genericDist) => result<genericDist, error>
} }

View File

@ -19,7 +19,7 @@ module Operation = {
| #Multiply | #Multiply
| #Subtract | #Subtract
| #Divide | #Divide
| #Exponentiate | #Power
| #Logarithm | #Logarithm
] ]
@ -28,7 +28,7 @@ module Operation = {
| #Add => \"+." | #Add => \"+."
| #Multiply => \"*." | #Multiply => \"*."
| #Subtract => \"-." | #Subtract => \"-."
| #Exponentiate => \"**" | #Power => \"**"
| #Divide => \"/." | #Divide => \"/."
| #Logarithm => (a, b) => log(a) /. log(b) | #Logarithm => (a, b) => log(a) /. log(b)
} }

View File

@ -2,17 +2,20 @@
type t = GenericDist_Types.genericDist type t = GenericDist_Types.genericDist
type error = GenericDist_Types.error type error = GenericDist_Types.error
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error> type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
type toSampleSetFn = t => result<array<float>, error> type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error> type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error> type pointwiseAddFn = (t, t) => result<t, error>
let sampleN = (t: t, n) => let sampleN = (t: t, n) =>
switch t { switch t {
| PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) | PointSet(r) => PointSetDist.sampleNRendered(n, r)
| Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) | Symbolic(r) => SymbolicDist.T.sampleN(n, r)
| SampleSet(r) => Ok(SampleSet.sampleN(r, n)) | SampleSet(r) => SampleSetDist.sampleN(r, n)
} }
// Draws `n` samples from `t` via `sampleN`, passes them to `SampleSetDist.make`,
// and converts a string failure from `make` into a `GenericDist_Types.error`.
let toSampleSetDist = (t: t, n) => {
  let samples = sampleN(t, n)
  samples->SampleSetDist.make->GenericDist_Types.Error.resultStringToResultError
}
let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f)) let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
let toString = (t: t) => let toString = (t: t) =>
@ -62,22 +65,16 @@ let toPointSet = (
switch (t: t) { switch (t: t) {
| PointSet(pointSet) => Ok(pointSet) | PointSet(pointSet) => Ok(pointSet)
| Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(~xSelection, xyPointLength, r)) | Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(~xSelection, xyPointLength, r))
| SampleSet(r) => { | SampleSet(r) =>
let response = SampleSet.toPointSetDist( SampleSetDist.toPointSetDist(
~samples=r, ~samples=r,
~samplingInputs={ ~samplingInputs={
sampleCount: sampleCount, sampleCount: sampleCount,
outputXYPoints: xyPointLength, outputXYPoints: xyPointLength,
pointSetDistLength: xyPointLength, pointSetDistLength: xyPointLength,
kernelWidth: None, kernelWidth: None,
}, },
(), )->GenericDist_Types.Error.resultStringToResultError
).pointSetDist
switch response {
| Some(r) => Ok(r)
| None => Error(Other("Converting sampleSet to pointSet failed"))
}
}
} }
} }
@ -91,7 +88,7 @@ let toSparkline = (t: t, ~sampleCount: int, ~bucketCount: int=20, unit): result<
t t
->toPointSet(~xSelection=#Linear, ~xyPointLength=bucketCount * 3, ~sampleCount, ()) ->toPointSet(~xSelection=#Linear, ~xyPointLength=bucketCount * 3, ~sampleCount, ())
->E.R.bind(r => ->E.R.bind(r =>
r->PointSetDist.toSparkline(bucketCount)->E.R2.errMap(r => Error(GenericDist_Types.Other(r))) r->PointSetDist.toSparkline(bucketCount)->GenericDist_Types.Error.resultStringToResultError
) )
module Truncate = { module Truncate = {
@ -166,10 +163,12 @@ module AlgebraicCombination = {
t1: t, t1: t,
t2: t, t2: t,
) => { ) => {
let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation) let fn = Operation.Algebraic.toFn(arithmeticOperation)
E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R2.fmap(((a, b)) => { E.R.merge(toSampleSet(t1), toSampleSet(t2))
Belt.Array.zip(a, b)->E.A2.fmap(((a, b)) => arithmeticOperation(a, b)) ->E.R.bind(((t1, t2)) => {
SampleSetDist.map2(~fn, ~t1, ~t2)->GenericDist_Types.Error.resultStringToResultError
}) })
->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
} }
//I'm (Ozzie) really just guessing here, very little idea what's best //I'm (Ozzie) really just guessing here, very little idea what's best
@ -200,13 +199,7 @@ module AlgebraicCombination = {
| Some(Error(e)) => Error(Other(e)) | Some(Error(e)) => Error(Other(e))
| None => | None =>
switch chooseConvolutionOrMonteCarlo(t1, t2) { switch chooseConvolutionOrMonteCarlo(t1, t2) {
| #CalculateWithMonteCarlo => | #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
runMonteCarlo(
toSampleSetFn,
arithmeticOperation,
t1,
t2,
)->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
| #CalculateWithConvolution => | #CalculateWithConvolution =>
runConvolution( runConvolution(
toPointSetFn, toPointSetFn,
@ -247,7 +240,7 @@ let pointwiseCombinationFloat = (
): result<t, error> => { ): result<t, error> => {
let m = switch arithmeticOperation { let m = switch arithmeticOperation {
| #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid) | #Add | #Subtract => Error(GenericDist_Types.DistributionVerticalShiftIsInvalid)
| (#Multiply | #Divide | #Exponentiate | #Logarithm) as arithmeticOperation => | (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
toPointSetFn(t)->E.R2.fmap(t => { toPointSetFn(t)->E.R2.fmap(t => {
//TODO: Move to PointSet codebase //TODO: Move to PointSet codebase
let fn = (secondary, main) => Operation.Scale.toFn(arithmeticOperation, main, secondary) let fn = (secondary, main) => Operation.Scale.toFn(arithmeticOperation, main, secondary)
@ -272,7 +265,7 @@ let mixture = (
~pointwiseAddFn: pointwiseAddFn, ~pointwiseAddFn: pointwiseAddFn,
) => { ) => {
if E.A.length(values) == 0 { if E.A.length(values) == 0 {
Error(GenericDist_Types.Other("mixture must have at least 1 element")) Error(GenericDist_Types.Other("Mixture error: mixture must have at least 1 element"))
} else { } else {
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
let properlyWeightedValues = let properlyWeightedValues =

View File

@ -1,11 +1,13 @@
type t = GenericDist_Types.genericDist type t = GenericDist_Types.genericDist
type error = GenericDist_Types.error type error = GenericDist_Types.error
type toPointSetFn = t => result<PointSetTypes.pointSetDist, error> type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
type toSampleSetFn = t => result<array<float>, error> type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error> type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error> type pointwiseAddFn = (t, t) => result<t, error>
let sampleN: (t, int) => result<array<float>, error> let sampleN: (t, int) => array<float>
let toSampleSetDist: (t, int) => Belt.Result.t<QuriSquiggleLang.SampleSetDist.t, error>
let fromFloat: float => t let fromFloat: float => t

View File

@ -1,6 +1,6 @@
type genericDist = type genericDist =
| PointSet(PointSetTypes.pointSetDist) | PointSet(PointSetTypes.pointSetDist)
| SampleSet(SampleSet.t) | SampleSet(SampleSetDist.t)
| Symbolic(SymbolicDistTypes.symbolicDist) | Symbolic(SymbolicDistTypes.symbolicDist)
@genType @genType
@ -10,6 +10,15 @@ type error =
| DistributionVerticalShiftIsInvalid | DistributionVerticalShiftIsInvalid
| Other(string) | Other(string)
// Helpers for lifting plain string messages into the `error` variant.
module Error = {
  type t = error

  // Wraps a message in the catch-all `Other` case.
  let fromString = (s: string): t => Other(s)

  // Turns a result whose failure is a string into one whose failure is an `error`.
  // The failure branch is rebuilt as `Error(Other(message))`; the rest is delegated
  // to `E.R2.errMap`.
  let resultStringToResultError: result<'a, string> => result<'a, error> = res =>
    E.R2.errMap(res, msg => Error(fromString(msg)))
}
module Operation = { module Operation = {
type direction = type direction =
| Algebraic | Algebraic
@ -20,7 +29,7 @@ module Operation = {
| #Multiply | #Multiply
| #Subtract | #Subtract
| #Divide | #Divide
| #Exponentiate | #Power
| #Logarithm | #Logarithm
] ]
@ -29,7 +38,7 @@ module Operation = {
| #Add => \"+." | #Add => \"+."
| #Multiply => \"*." | #Multiply => \"*."
| #Subtract => \"-." | #Subtract => \"-."
| #Exponentiate => \"**" | #Power => \"**"
| #Divide => \"/." | #Divide => \"/."
| #Logarithm => (a, b) => log(a) /. log(b) | #Logarithm => (a, b) => log(a) /. log(b)
} }
@ -143,8 +152,8 @@ module Constructors = {
ToDistCombination(Algebraic, #Logarithm, #Dist(dist2)), ToDistCombination(Algebraic, #Logarithm, #Dist(dist2)),
dist1, dist1,
) )
let algebraicExponentiate = (dist1, dist2): t => FromDist( let algebraicPower = (dist1, dist2): t => FromDist(
ToDistCombination(Algebraic, #Exponentiate, #Dist(dist2)), ToDistCombination(Algebraic, #Power, #Dist(dist2)),
dist1, dist1,
) )
let pointwiseAdd = (dist1, dist2): t => FromDist( let pointwiseAdd = (dist1, dist2): t => FromDist(
@ -167,8 +176,8 @@ module Constructors = {
ToDistCombination(Pointwise, #Logarithm, #Dist(dist2)), ToDistCombination(Pointwise, #Logarithm, #Dist(dist2)),
dist1, dist1,
) )
let pointwiseExponentiate = (dist1, dist2): t => FromDist( let pointwisePower = (dist1, dist2): t => FromDist(
ToDistCombination(Pointwise, #Exponentiate, #Dist(dist2)), ToDistCombination(Pointwise, #Power, #Dist(dist2)),
dist1, dist1,
) )
} }

View File

@ -114,7 +114,7 @@ let combineShapesContinuousContinuous = (
| #Subtract => (m1, m2) => m1 -. m2 | #Subtract => (m1, m2) => m1 -. m2
| #Multiply => (m1, m2) => m1 *. m2 | #Multiply => (m1, m2) => m1 *. m2
| #Divide => (m1, mInv2) => m1 *. mInv2 | #Divide => (m1, mInv2) => m1 *. mInv2
| #Exponentiate => (m1, mInv2) => m1 ** mInv2 | #Power => (m1, mInv2) => m1 ** mInv2
| #Logarithm => (m1, m2) => log(m1) /. log(m2) | #Logarithm => (m1, m2) => log(m1) /. log(m2)
} // note: here, mInv2 = mean(1 / t2) ~= 1 / mean(t2) } // note: here, mInv2 = mean(1 / t2) ~= 1 / mean(t2)
@ -124,7 +124,7 @@ let combineShapesContinuousContinuous = (
| #Add => (v1, v2, _, _) => v1 +. v2 | #Add => (v1, v2, _, _) => v1 +. v2
| #Subtract => (v1, v2, _, _) => v1 +. v2 | #Subtract => (v1, v2, _, _) => v1 +. v2
| #Multiply => (v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2. | #Multiply => (v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2.
| #Exponentiate => (v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2. | #Power => (v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2.
| #Logarithm => (v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2. | #Logarithm => (v1, v2, m1, m2) => v1 *. v2 +. v1 *. m2 ** 2. +. v2 *. m1 ** 2.
| #Divide => (v1, vInv2, m1, mInv2) => v1 *. vInv2 +. v1 *. mInv2 ** 2. +. vInv2 *. m1 ** 2. | #Divide => (v1, vInv2, m1, mInv2) => v1 *. vInv2 +. v1 *. mInv2 ** 2. +. vInv2 *. m1 ** 2.
} }
@ -233,7 +233,7 @@ let combineShapesContinuousDiscrete = (
() ()
} }
| #Multiply | #Multiply
| #Exponentiate | #Power
| #Logarithm | #Logarithm
| #Divide => | #Divide =>
for j in 0 to t2n - 1 { for j in 0 to t2n - 1 {

View File

@ -0,0 +1,68 @@
/*
  Smart-constructor module for SampleSetDist.t: `make` is the intended entry point for
  building a value, and it rejects sample arrays that are too short.
  Note that `t` is currently exposed as `array<float>` in the signature below, so the
  compiler does not actually force callers to go through `make`.
  https://stackoverflow.com/questions/66909578/how-to-make-a-type-constructor-private-in-rescript-except-in-current-module
 */
module T: {
  //This really should be hidden (remove the array<float>). The reason it isn't is to act as an escape hatch in JS__Test.ts.
  //When we get a good functional library in TS, we could refactor that out.
  @genType
  type t = array<float>
  let make: array<float> => result<t, string>
  let get: t => array<float>
} = {
  type t = array<float>

  // Accepts the samples only when there are more than 5 of them.
  let make = (samples: array<float>) =>
    if E.A.length(samples) <= 5 {
      Error("too small")
    } else {
      Ok(samples)
    }

  // Returns the underlying array of samples.
  let get = (samples: t) => samples
}
include T
// Number of samples held by the distribution.
let length = (t: t) => t->get->E.A.length
/*
  Converts the samples to a point set distribution by delegating to
  SampleSetDist_ToPointSet.toPointSetDist and reading its optional `pointSetDist` field.
  A missing field becomes Error("Failed to convert to PointSetDist").

  TODO: Refactor to get a more precise estimate. Also, this code is fairly messy and could
  use some refactoring.
 */
let toPointSetDist = (
  ~samples: t,
  ~samplingInputs: SamplingInputs.samplingInputs,
): result<PointSetTypes.pointSetDist, string> => {
  let rawSamples = get(samples)
  let converted = SampleSetDist_ToPointSet.toPointSetDist(
    ~samples=rawSamples,
    ~samplingInputs,
    (),
  )
  converted.pointSetDist->E.O2.toResult("Failed to convert to PointSetDist")
}
// Picks a single sample from the distribution at a randomly chosen index.
let sample = (t: t): float => {
  let samples = get(t)
  let lastIndex = E.A.length(samples) - 1
  // NOTE(review): this assumes E.Int.random treats ~max as inclusive. If ~max is
  // exclusive, the last sample can never be picked -- confirm against E.Int.random.
  let index = E.Int.random(~min=0, ~max=lastIndex)
  E.A.unsafe_get(samples, index)
}
/*
  If asked for a number of samples less than or equal to the length of the distribution,
  return the first n samples of this distribution.
  Otherwise, return n random samples of the distribution.
  The former helps in cases where multiple distributions are correlated.
  However, if n > length(t), then there's no clear right answer, so we just randomly
  sample everything.
 */
let sampleN = (t: t, n) => {
  let samples = get(t)
  if n > E.A.length(samples) {
    Belt.Array.makeBy(n, _ => sample(t))
  } else {
    E.A.slice(samples, ~offset=0, ~len=n)
  }
}
// Combines two sample sets by pairing their samples positionally and applying `fn` to each
// pair, then re-validates the combined samples through `make`.
//TODO: Figure out what to do if distributions are different lengths. ``zip`` is kind of inelegant for this.
let map2 = (~fn: (float, float) => float, ~t1: t, ~t2: t) =>
  get(t1)
  ->Belt.Array.zip(get(t2))
  ->E.A2.fmap(((x, y)) => fn(x, y))
  ->make

View File

@ -1,4 +1,4 @@
//The math here was taken from https://github.com/jasondavies/science.js/blob/master/src/stats/bandwidth.js //The math here was taken from https://github.com/jasondavies/science.js/blob/master/src/stats/bandwidth.js
let len = x => E.A.length(x) |> float_of_int let len = x => E.A.length(x) |> float_of_int

View File

@ -1,8 +1,3 @@
@genType
type t = array<float>
// TODO: Refactor to raise correct error when not enough samples
module Internals = { module Internals = {
module Types = { module Types = {
type samplingStats = { type samplingStats = {
@ -75,7 +70,7 @@ module Internals = {
let formatUnitWidth = w => Jstat.max([w, 1.0]) |> int_of_float let formatUnitWidth = w => Jstat.max([w, 1.0]) |> int_of_float
let suggestedUnitWidth = (samples, outputXYPoints) => { let suggestedUnitWidth = (samples, outputXYPoints) => {
let suggestedXWidth = Bandwidth.nrd0(samples) let suggestedXWidth = SampleSetDist_Bandwidth.nrd0(samples)
xWidthToUnitWidth(samples, outputXYPoints, suggestedXWidth) xWidthToUnitWidth(samples, outputXYPoints, suggestedXWidth)
} }
@ -102,7 +97,7 @@ let toPointSetDist = (
let pdf = let pdf =
continuousPart |> E.A.length > 5 continuousPart |> E.A.length > 5
? { ? {
let _suggestedXWidth = Bandwidth.nrd0(continuousPart) let _suggestedXWidth = SampleSetDist_Bandwidth.nrd0(continuousPart)
// todo: This does some recalculating from the last step. // todo: This does some recalculating from the last step.
let _suggestedUnitWidth = Internals.T.suggestedUnitWidth( let _suggestedUnitWidth = Internals.T.suggestedUnitWidth(
continuousPart, continuousPart,
@ -145,25 +140,3 @@ let toPointSetDist = (
samplesParse samplesParse
} }
//Randomly get one sample from the distribution
let sample = (t: t): float => {
let i = E.Int.random(~min=0, ~max=E.A.length(t) - 1)
E.A.unsafe_get(t, i)
}
/*
If asked for a length of samples shorter or equal the length of the distribution,
return this first n samples of this distribution.
Else, return n random samples of the distribution.
The former helps in cases where multiple distributions are correlated.
However, if n > length(t), then there's no clear right answer, so we just randomly
sample everything.
*/
let sampleN = (t: t, n) => {
if n <= E.A.length(t) {
E.A.slice(t, ~offset=0, ~len=n)
} else {
Belt.Array.makeBy(n, _ => sample(t))
}
}

View File

@ -118,7 +118,7 @@ module PointwiseCombination = {
switch pointwiseOp { switch pointwiseOp {
| #Add => pointwiseAdd(evaluationParams, t1, t2) | #Add => pointwiseAdd(evaluationParams, t1, t2)
| #Multiply => pointwiseCombine(\"*.", evaluationParams, t1, t2) | #Multiply => pointwiseCombine(\"*.", evaluationParams, t1, t2)
| #Exponentiate => pointwiseCombine(\"**", evaluationParams, t1, t2) | #Power => pointwiseCombine(\"**", evaluationParams, t1, t2)
} }
} }

View File

@ -218,15 +218,14 @@ module SamplingDistribution = {
algebraicOp, algebraicOp,
a, a,
b, b,
) ) |> E.O.toResult("Could not get samples")
let pointSetDist = let sampleSetDist = samples -> E.R.bind(SampleSetDist.make)
samples
|> E.O.fmap(r => let pointSetDist =
SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()) sampleSetDist
) -> E.R.bind(r =>
|> E.O.bind(_, r => r.pointSetDist) SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r));
|> E.O.toResult("No response")
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r))) pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
}) })
} }

View File

@ -227,7 +227,7 @@ let all = [
}, },
(), (),
), ),
makeRenderedDistFloat("scaleExp", (dist, float) => verticalScaling(#Exponentiate, dist, float)), makeRenderedDistFloat("scaleExp", (dist, float) => verticalScaling(#Power, dist, float)),
makeRenderedDistFloat("scaleMultiply", (dist, float) => verticalScaling(#Multiply, dist, float)), makeRenderedDistFloat("scaleMultiply", (dist, float) => verticalScaling(#Multiply, dist, float)),
makeRenderedDistFloat("scaleLog", (dist, float) => verticalScaling(#Logarithm, dist, float)), makeRenderedDistFloat("scaleLog", (dist, float) => verticalScaling(#Logarithm, dist, float)),
Multimodal._function, Multimodal._function,

View File

@ -144,11 +144,11 @@ module MathAdtToDistDst = {
| ("subtract", _) => Error("Subtraction needs two operands") | ("subtract", _) => Error("Subtraction needs two operands")
| ("multiply", [l, r]) => toOkAlgebraic((#Multiply, l, r)) | ("multiply", [l, r]) => toOkAlgebraic((#Multiply, l, r))
| ("multiply", _) => Error("Multiplication needs two operands") | ("multiply", _) => Error("Multiplication needs two operands")
| ("pow", [l, r]) => toOkAlgebraic((#Exponentiate, l, r)) | ("pow", [l, r]) => toOkAlgebraic((#Power, l, r))
| ("pow", _) => Error("Exponentiation needs two operands") | ("pow", _) => Error("Exponentiation needs two operands")
| ("dotMultiply", [l, r]) => toOkPointwise((#Multiply, l, r)) | ("dotMultiply", [l, r]) => toOkPointwise((#Multiply, l, r))
| ("dotMultiply", _) => Error("Dotwise multiplication needs two operands") | ("dotMultiply", _) => Error("Dotwise multiplication needs two operands")
| ("dotPow", [l, r]) => toOkPointwise((#Exponentiate, l, r)) | ("dotPow", [l, r]) => toOkPointwise((#Power, l, r))
| ("dotPow", _) => Error("Dotwise exponentiation needs two operands") | ("dotPow", _) => Error("Dotwise exponentiation needs two operands")
| ("rightLogShift", [l, r]) => toOkPointwise((#Add, l, r)) | ("rightLogShift", [l, r]) => toOkPointwise((#Add, l, r))
| ("rightLogShift", _) => Error("Dotwise addition needs two operands") | ("rightLogShift", _) => Error("Dotwise addition needs two operands")

View File

@ -18,8 +18,8 @@ module Helpers = {
| "divide" => #Divide | "divide" => #Divide
| "log" => #Logarithm | "log" => #Logarithm
| "dotDivide" => #Divide | "dotDivide" => #Divide
| "pow" => #Exponentiate | "pow" => #Power
| "dotPow" => #Exponentiate | "dotPow" => #Power
| "multiply" => #Multiply | "multiply" => #Multiply
| "dotMultiply" => #Multiply | "dotMultiply" => #Multiply
| "dotLog" => #Logarithm | "dotLog" => #Logarithm

View File

@ -22,3 +22,6 @@ type resultDist = result<genericDist, error>
type resultFloat = result<float, error> type resultFloat = result<float, error>
@genType @genType
type resultString = result<string, error> type resultString = result<string, error>
// Alias of SampleSetDist.make, annotated with @genType so that it is emitted in the
// generated TypeScript bindings.
@genType
let makeSampleSetDist = SampleSetDist.make

View File

@ -6,12 +6,12 @@ type algebraicOperation = [
| #Multiply | #Multiply
| #Subtract | #Subtract
| #Divide | #Divide
| #Exponentiate | #Power
| #Logarithm | #Logarithm
] ]
@genType @genType
type pointwiseOperation = [#Add | #Multiply | #Exponentiate] type pointwiseOperation = [#Add | #Multiply | #Power]
type scaleOperation = [#Multiply | #Exponentiate | #Logarithm | #Divide] type scaleOperation = [#Multiply | #Power | #Logarithm | #Divide]
type distToFloatOperation = [ type distToFloatOperation = [
| #Pdf(float) | #Pdf(float)
| #Cdf(float) | #Cdf(float)
@ -27,7 +27,7 @@ module Algebraic = {
| #Add => \"+." | #Add => \"+."
| #Subtract => \"-." | #Subtract => \"-."
| #Multiply => \"*." | #Multiply => \"*."
| #Exponentiate => \"**" | #Power => \"**"
| #Divide => \"/." | #Divide => \"/."
| #Logarithm => (a, b) => log(a) /. log(b) | #Logarithm => (a, b) => log(a) /. log(b)
} }
@ -43,7 +43,7 @@ module Algebraic = {
| #Add => "+" | #Add => "+"
| #Subtract => "-" | #Subtract => "-"
| #Multiply => "*" | #Multiply => "*"
| #Exponentiate => "**" | #Power => "**"
| #Divide => "/" | #Divide => "/"
| #Logarithm => "log" | #Logarithm => "log"
} }
@ -56,7 +56,7 @@ module Pointwise = {
let toString = x => let toString = x =>
switch x { switch x {
| #Add => "+" | #Add => "+"
| #Exponentiate => "^" | #Power => "**"
| #Multiply => "*" | #Multiply => "*"
} }
@ -83,7 +83,7 @@ module Scale = {
switch x { switch x {
| #Multiply => \"*." | #Multiply => \"*."
| #Divide => \"/." | #Divide => \"/."
| #Exponentiate => \"**" | #Power => \"**"
| #Logarithm => (a, b) => log(a) /. log(b) | #Logarithm => (a, b) => log(a) /. log(b)
} }
@ -91,7 +91,7 @@ module Scale = {
switch operation { switch operation {
| #Multiply => j`verticalMultiply($value, $scaleBy) ` | #Multiply => j`verticalMultiply($value, $scaleBy) `
| #Divide => j`verticalDivide($value, $scaleBy) ` | #Divide => j`verticalDivide($value, $scaleBy) `
| #Exponentiate => j`verticalExponentiate($value, $scaleBy) ` | #Power => j`verticalPower($value, $scaleBy) `
| #Logarithm => j`verticalLog($value, $scaleBy) ` | #Logarithm => j`verticalLog($value, $scaleBy) `
} }
@ -99,7 +99,7 @@ module Scale = {
switch x { switch x {
| #Multiply => (a, b) => Some(a *. b) | #Multiply => (a, b) => Some(a *. b)
| #Divide => (a, b) => Some(a /. b) | #Divide => (a, b) => Some(a /. b)
| #Exponentiate => (_, _) => None | #Power => (_, _) => None
| #Logarithm => (_, _) => None | #Logarithm => (_, _) => None
} }
@ -107,7 +107,7 @@ module Scale = {
switch x { switch x {
| #Multiply => (_, _) => None // TODO: this could probably just be multiplied out (using Continuous.scaleBy) | #Multiply => (_, _) => None // TODO: this could probably just be multiplied out (using Continuous.scaleBy)
| #Divide => (_, _) => None | #Divide => (_, _) => None
| #Exponentiate => (_, _) => None | #Power => (_, _) => None
| #Logarithm => (_, _) => None | #Logarithm => (_, _) => None
} }
} }