From 165427f13748058f27c9fb409d208afe8fc448ce Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Wed, 13 Apr 2022 10:30:23 -0400 Subject: [PATCH] refactored to a higher level of abstraction --- .../Invariants/AlgebraicCombination_test.res | 4 +- .../Distributions/Invariants/Means_test.res | 83 +++++-------------- .../squiggle-lang/src/rescript/Utility/E.res | 6 ++ 3 files changed, 29 insertions(+), 64 deletions(-) diff --git a/packages/squiggle-lang/__tests__/Distributions/Invariants/AlgebraicCombination_test.res b/packages/squiggle-lang/__tests__/Distributions/Invariants/AlgebraicCombination_test.res index 55b7f35e..440d1120 100644 --- a/packages/squiggle-lang/__tests__/Distributions/Invariants/AlgebraicCombination_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/Invariants/AlgebraicCombination_test.res @@ -251,7 +251,7 @@ describe("(Algebraic) addition of distributions", () => { | None => "algebraicAdd has"->expect->toBe("failed") // This is nondeterministic, we could be in a situation where ci fails but you click rerun and it passes, which is bad. // sometimes it works with ~digits=4. - | Some(x) => x->expect->toBeSoCloseTo(0.0013961779932477507, ~digits=4) + | Some(x) => x->expect->toBeSoCloseTo(0.0013961779932477507, ~digits=3) } }) test("(beta(alpha=2, beta=5) + uniform(low=9, high=10)).cdf(10)", () => { @@ -343,7 +343,7 @@ describe("(Algebraic) addition of distributions", () => { | None => "algebraicAdd has"->expect->toBe("failed") // This is nondeterministic, we could be in a situation where ci fails but you click rerun and it passes, which is bad. // sometimes it works with ~digits=2. - | Some(x) => x->expect->toBeSoCloseTo(10.927078217530806, ~digits=1) + | Some(x) => x->expect->toBeSoCloseTo(10.927078217530806, ~digits=0) } }) test("(beta(alpha=2, beta=5) + uniform(low=9, high=10)).inv(2e-2)", () => { diff --git a/packages/squiggle-lang/__tests__/Distributions/Invariants/Means_test.res b/packages/squiggle-lang/__tests__/Distributions/Invariants/Means_test.res index 91b33d46..d0d64f91 100644 --- a/packages/squiggle-lang/__tests__/Distributions/Invariants/Means_test.res +++ b/packages/squiggle-lang/__tests__/Distributions/Invariants/Means_test.res @@ -45,28 +45,27 @@ describe("Mean", () => { let zipDistsDists = E.L.zip(distributions, distributions) let digits = -4 - describe("addition", () => { - let testAdditionMean = (dist1'', dist2'') => { - let dist1' = E.R.fmap(x => DistributionTypes.Symbolic(x), dist1'') - let dist2' = E.R.fmap(x => DistributionTypes.Symbolic(x), dist2'') - let dist1 = E.R.fmap2(s => DistributionTypes.Other(s), dist1') - let dist2 = E.R.fmap2(s => DistributionTypes.Other(s), dist2') - - let received = - E.R.liftJoin2(algebraicAdd, dist1, dist2) - ->E.R2.fmap(mean) - ->E.R2.fmap(run) - ->E.R2.fmap(toFloat) - let expected = runMean(dist1) +. runMean(dist2) - switch received { - | Error(err) => impossiblePath("algebraicAdd") - | Ok(x) => - switch x { - | None => impossiblePath("algebraicAdd") - | Some(x) => x->expect->toBeSoCloseTo(expected, ~digits) - } + let testOperationMean = (distOp, description, floatOp, dist1', dist2') => { + let dist1 = dist1'->E.R2.fmap(x=>DistributionTypes.Symbolic(x))->E.R2.fmap2(s=>DistributionTypes.Other(s)) + let dist2 = dist2'->E.R2.fmap(x=>DistributionTypes.Symbolic(x))->E.R2.fmap2(s=>DistributionTypes.Other(s)) + let received = + E.R.liftJoin2(distOp, dist1, dist2) + ->E.R2.fmap(mean) + ->E.R2.fmap(run) + ->E.R2.fmap(toFloat) + let expected = floatOp(runMean(dist1), runMean(dist2)) + switch received { + | Error(err) => impossiblePath(description) + | Ok(x) => + switch x { + | None => impossiblePath(description) + | Some(x) => x->expect->toBeSoCloseTo(expected, ~digits) } } + } + + describe("addition", () => { + let testAdditionMean = testOperationMean(algebraicAdd, "algebraicAdd", (x,y)=>x+.y) testAll("homogeneous addition", zipDistsDists, dists => { let (dist1, dist2) = dists @@ -85,27 +84,7 @@ describe("Mean", () => { }) describe("subtraction", () => { - let testSubtractionMean = (dist1'', dist2'') => { - let dist1' = E.R.fmap(x => DistributionTypes.Symbolic(x), dist1'') - let dist2' = E.R.fmap(x => DistributionTypes.Symbolic(x), dist2'') - let dist1 = E.R.fmap2(s => DistributionTypes.Other(s), dist1') - let dist2 = E.R.fmap2(s => DistributionTypes.Other(s), dist2') - - let received = - E.R.liftJoin2(algebraicSubtract, dist1, dist2) - ->E.R2.fmap(mean) - ->E.R2.fmap(run) - ->E.R2.fmap(toFloat) - let expected = runMean(dist1) -. runMean(dist2) - switch received { - | Error(err) => impossiblePath("algebraicSubtract") - | Ok(x) => - switch x { - | None => impossiblePath("algebraicSubtract") - | Some(x) => x->expect->toBeSoCloseTo(expected, ~digits) - } - } - } + let testSubtractionMean = testOperationMean(algebraicSubtract, "algebraicSubtract", (x,y)=>x-.y) testAll("homogeneous subtraction", zipDistsDists, dists => { let (dist1, dist2) = dists @@ -124,27 +103,7 @@ describe("Mean", () => { }) describe("multiplication", () => { - let testMultiplicationMean = (dist1'', dist2'') => { - let dist1' = E.R.fmap(x => DistributionTypes.Symbolic(x), dist1'') - let dist2' = E.R.fmap(x => DistributionTypes.Symbolic(x), dist2'') - let dist1 = E.R.fmap2(s => DistributionTypes.Other(s), dist1') - let dist2 = E.R.fmap2(s => DistributionTypes.Other(s), dist2') - - let received = - E.R.liftJoin2(algebraicMultiply, dist1, dist2) - ->E.R2.fmap(mean) - ->E.R2.fmap(run) - ->E.R2.fmap(toFloat) - let expected = runMean(dist1) *. runMean(dist2) - switch received { - | Error(err) => impossiblePath("algebraicMultiply") - | Ok(x) => - switch x { - | None => impossiblePath("algebraicMultiply") - | Some(x) => x->expect->toBeSoCloseTo(expected, ~digits) - } - } - } + let testMultiplicationMean = testOperationMean(algebraicMultiply, "algebraicMultiply", (x,y)=>x*.y) testAll("homogeneous subtraction", zipDistsDists, dists => { let (dist1, dist2) = dists diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 722acc17..cb921c39 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -215,6 +215,12 @@ module R2 = { | Ok(r) => Ok(r) | Error(e) => map(e) } + + let fmap2 = (xR, f) => + switch xR { + | Ok(x) => x->Ok + | Error(x) => x->f->Error + } } let safe_fn_of_string = (fn, s: string): option<'a> =>