From 3cf336d7203e9440f60a88a22912ebf96a611b60 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Mon, 28 Mar 2022 15:14:39 -0400 Subject: [PATCH] Starting to add tests to rescript --- .../GenericDist/GenericOperation__Test.res | 104 ++++++++++++++++++ .../GenericDist_GenericOperation.res | 53 ++++++++- .../GenericDist_GenericOperation.resi | 16 ++- .../GenericDist/GenericDist_Types.res | 7 ++ .../src/rescript/pointSetDist/XYShape.res | 2 +- .../src/rescript/sampleSet/SampleSet.res | 2 + 6 files changed, 178 insertions(+), 6 deletions(-) create mode 100644 packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res diff --git a/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res new file mode 100644 index 00000000..2c1a44c4 --- /dev/null +++ b/packages/squiggle-lang/__tests__/GenericDist/GenericOperation__Test.res @@ -0,0 +1,104 @@ +open Jest +open Expect + +let params: GenericDist_GenericOperation.params = { + sampleCount: 100, + xyPointLength: 100, +} + +let normalDist: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 5.0, stdev: 2.0})) +let normalDist10: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 10.0, stdev: 2.0})) +let normalDist20: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 20.0, stdev: 2.0})) +let uniformDist: GenericDist_Types.genericDist = #Symbolic(#Uniform({low: 9.0, high: 10.0})) + +let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output) +let {run, fmap} = module(GenericDist_GenericOperation) +let run = run(params) +let fmap = fmap(params) +let toExt: option<'a> => 'a = E.O.toExt( + "Should be impossible to reach (This error is in test file)", +) + +describe("normalize", () => { + test("has no impact on normal dist", () => { + let result = run(#fromDist(#toDist(#normalize), normalDist)) + expect(result)->toEqual(#Dist(normalDist)) + }) +}) + +describe("mean", () => { + test("for a normal distribution", () => { + let result = GenericDist_GenericOperation.run(params, #fromDist(#toFloat(#Mean), normalDist)) + expect(result)->toEqual(#Float(5.0)) + }) +}) + +describe("mixture", () => { + test("on two normal distributions", () => { + let result = + run(#mixture([(normalDist10, 0.5), (normalDist20, 0.5)])) + |> fmap(#fromDist(#toFloat(#Mean))) + |> toFloat + |> toExt + expect(result)->toBeCloseTo(15.28) + }) +}) + +describe("toPointSet", () => { + test("on symbolic normal distribution", () => { + let result = + run(#fromDist(#toDist(#toPointSet), normalDist)) + |> fmap(#fromDist(#toFloat(#Mean))) + |> toFloat + |> toExt + expect(result)->toBeCloseTo(5.09) + }) + + test("on sample set distribution with under 4 points", () => { + let result = + run(#fromDist(#toDist(#toPointSet), #SampleSet([0.0, 1.0, 2.0, 3.0]))) |> fmap( + #fromDist(#toFloat(#Mean)), + ) + expect(result)->toEqual(#Error(Other("Converting sampleSet to pointSet failed"))) + }) + + test("back and forth", () => { + let result = + run(#fromDist(#toDist(#toPointSet), normalDist)) + |> fmap(#fromDist(#toDist(#toSampleSet(1000)))) + |> fmap(#fromDist(#toDist(#consoleLog))) + |> fmap(#fromDist(#toDist(#toPointSet))) + |> fmap(#fromDist(#toDist(#consoleLog))) + |> fmap(#fromDist(#toFloat(#Mean))) + |> toFloat + |> toExt + expect(result)->toBeCloseTo(5.09) + }) + + test("on sample set distribution", () => { + let result = + run( + #fromDist( + #toDist(#toPointSet), + #SampleSet([ + 0.0, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + ]), + ), + ) + |> fmap(#fromDist(#toFloat(#Mean))) + |> toFloat + |> toExt + Js.log(result) + expect(result)->toBeCloseTo(5.09) + }) +}) diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res index cc542ef3..9993ef20 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.res @@ -16,6 +16,32 @@ type outputType = [ | #String(string) ] +module Output = { + let toDist = (o: outputType) => + switch o { + | #Dist(d) => Some(d) + | _ => None + } + + let toFloat = (o: outputType) => + switch o { + | #Float(d) => Some(d) + | _ => None + } + + let toString = (o: outputType) => + switch o { + | #String(d) => Some(d) + | _ => None + } + + let toError = (o: outputType) => + switch o { + | #Error(d) => Some(d) + | _ => None + } +} + let fromResult = (r: result): outputType => switch r { | Ok(o) => o @@ -71,12 +97,13 @@ let rec run = (extra, fnName: operation): outputType => { |> E.R.fmap(r => #Float(r)) |> fromResult | #toString => dist |> GenericDist.toString |> (r => #String(r)) + | #toDist(#consoleLog) => { + Js.log2("Console log requested: ", dist) + #Dist(dist) + } | #toDist(#normalize) => dist |> GenericDist.normalize |> (r => #Dist(r)) | #toDist(#truncate(left, right)) => - dist - |> GenericDist.truncate(toPointSet, left, right) - |> E.R.fmap(r => #Dist(r)) - |> fromResult + dist |> GenericDist.truncate(toPointSet, left, right) |> E.R.fmap(r => #Dist(r)) |> fromResult | #toDist(#toPointSet) => dist |> GenericDist.toPointSet(xyPointLength) @@ -109,3 +136,21 @@ let rec run = (extra, fnName: operation): outputType => { GenericDist.mixture(scaleMultiply, pointwiseAdd, dists) |> E.R.fmap(r => #Dist(r)) |> fromResult } } + +let runFromDist = (extra, fnName, dist) => run(extra, #fromDist(fnName, dist)) +let runFromFloat = (extra, fnName, float) => run(extra, #fromFloat(fnName, float)) + +let fmap = ( + extra, + fn: GenericDist_Types.Operation.singleParamaterFunction, + input: outputType, +): outputType => { + let newFnCall: result = switch (fn, input) { + | (#fromDist(fromDist), #Dist(o)) => Ok(#fromDist(fromDist, o)) + | (#fromFloat(fromDist), #Float(o)) => Ok(#fromFloat(fromDist, o)) + | (_, #Error(r)) => Error(r) + | (#fromDist(_), _) => Error(Other("Expected dist, got something else")) + | (#fromFloat(_), _) => Error(Other("Expected float, got something else")) + } + newFnCall |> E.R.fmap(r => run(extra, r)) |> fromResult +} diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi index 84822876..d05f4ddc 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_GenericOperation.resi @@ -10,4 +10,18 @@ type outputType = [ | #String(string) ] -let run: (params, GenericDist_Types.Operation.genericFunctionCall) => outputType \ No newline at end of file +let run: (params, GenericDist_Types.Operation.genericFunctionCall) => outputType +let runFromDist: ( + params, + GenericDist_Types.Operation.fromDist, + GenericDist_Types.genericDist, +) => outputType +let runFromFloat: (params, GenericDist_Types.Operation.fromDist, float) => outputType +let fmap: (params, GenericDist_Types.Operation.singleParamaterFunction, outputType) => outputType + +module Output: { + let toDist: outputType => option + let toFloat: outputType => option + let toString: outputType => option + let toError: outputType => option +} diff --git a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res index ecf8398f..7aad46b8 100644 --- a/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res +++ b/packages/squiggle-lang/src/rescript/GenericDist/GenericDist_Types.res @@ -48,6 +48,7 @@ module Operation = { | #toPointSet | #toSampleSet(int) | #truncate(option, option) + | #consoleLog ] type toFloatArray = [ @@ -61,6 +62,11 @@ module Operation = { | #toString ] + type singleParamaterFunction = [ + | #fromDist(fromDist) + | #fromFloat(fromDist) + ] + type genericFunctionCall = [ | #fromDist(fromDist, genericDist) | #fromFloat(fromDist, float) @@ -78,6 +84,7 @@ module Operation = { | #toDist(#toPointSet) => `toPointSet` | #toDist(#toSampleSet(r)) => `toSampleSet${E.I.toString(r)}` | #toDist(#truncate(_, _)) => `truncate` + | #toDist(#consoleLog) => `consoleLog` | #toString => `toString` | #toDistCombination(#Algebraic, _, _) => `algebraic` | #toDistCombination(#Pointwise, _, _) => `pointwise` diff --git a/packages/squiggle-lang/src/rescript/pointSetDist/XYShape.res b/packages/squiggle-lang/src/rescript/pointSetDist/XYShape.res index 048b571f..6cadec60 100644 --- a/packages/squiggle-lang/src/rescript/pointSetDist/XYShape.res +++ b/packages/squiggle-lang/src/rescript/pointSetDist/XYShape.res @@ -254,7 +254,7 @@ module PointwiseCombination = { j = t2n; continue; } else { - console.log("Error!", i, j); + console.log("PointwiseCombination Error", i, j); } outX.push(x); diff --git a/packages/squiggle-lang/src/rescript/sampleSet/SampleSet.res b/packages/squiggle-lang/src/rescript/sampleSet/SampleSet.res index 6b20d946..746f13d7 100644 --- a/packages/squiggle-lang/src/rescript/sampleSet/SampleSet.res +++ b/packages/squiggle-lang/src/rescript/sampleSet/SampleSet.res @@ -1,3 +1,5 @@ +// TODO: Refactor to raise correct error when not enough samples + module Internals = { module Types = { type samplingStats = {