From 4338f482ef37d9d5e0017c3a7e44c74e9c79723e Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Sat, 9 Apr 2022 21:24:44 -0400 Subject: [PATCH] Added genType to SampleSetDist to make pass tests, other minor fixes --- .../DistributionOperation.res | 14 ++++---- .../Distributions/GenericDist/GenericDist.res | 35 +++++++------------ .../GenericDist/GenericDist.resi | 3 +- .../SampleSetDist/SampleSetDist.res | 3 +- .../src/rescript/OldInterpreter/ASTTypes.res | 2 +- 5 files changed, 26 insertions(+), 31 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res index bb5f4f1a..71776f61 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res @@ -128,7 +128,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { ->E.R2.fmap(r => Dist(r)) ->OutputLocal.fromResult | ToDist(ToSampleSet(n)) => - dist->GenericDist.toSampleSetDist(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult + dist + ->GenericDist.toSampleSetDist(n) + ->E.R2.fmap(r => Dist(SampleSet(r))) + ->OutputLocal.fromResult | ToDist(ToPointSet) => dist ->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ()) @@ -204,7 +207,8 @@ module Constructors = { C.truncate(dist, leftCutoff, rightCutoff)->run(~env)->toDistR let inspect = (~env, dist) => C.inspect(dist)->run(~env)->toDistR let toString = (~env, dist) => C.toString(dist)->run(~env)->toStringR - let toSparkline = (~env, dist, bucketCount) => C.toSparkline(dist, bucketCount)->run(~env)->toStringR + let toSparkline = (~env, dist, bucketCount) => + C.toSparkline(dist, bucketCount)->run(~env)->toStringR let algebraicAdd = (~env, dist1, dist2) => C.algebraicAdd(dist1, dist2)->run(~env)->toDistR let algebraicMultiply = (~env, dist1, dist2) => C.algebraicMultiply(dist1, dist2)->run(~env)->toDistR @@ -213,8 +217,7 @@ module Constructors = { C.algebraicSubtract(dist1, dist2)->run(~env)->toDistR let algebraicLogarithm = (~env, dist1, dist2) => C.algebraicLogarithm(dist1, dist2)->run(~env)->toDistR - let algebraicPower = (~env, dist1, dist2) => - C.algebraicPower(dist1, dist2)->run(~env)->toDistR + let algebraicPower = (~env, dist1, dist2) => C.algebraicPower(dist1, dist2)->run(~env)->toDistR let pointwiseAdd = (~env, dist1, dist2) => C.pointwiseAdd(dist1, dist2)->run(~env)->toDistR let pointwiseMultiply = (~env, dist1, dist2) => C.pointwiseMultiply(dist1, dist2)->run(~env)->toDistR @@ -223,6 +226,5 @@ module Constructors = { C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR let pointwiseLogarithm = (~env, dist1, dist2) => C.pointwiseLogarithm(dist1, dist2)->run(~env)->toDistR - let pointwisePower = (~env, dist1, dist2) => - C.pointwisePower(dist1, dist2)->run(~env)->toDistR + let pointwisePower = (~env, dist1, dist2) => C.pointwisePower(dist1, dist2)->run(~env)->toDistR } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index 772e83f7..b14ea27f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -5,14 +5,16 @@ type toPointSetFn = t => result type toSampleSetFn = t => result type scaleMultiplyFn = (t, float) => result type pointwiseAddFn = (t, t) => result + let sampleN = (t: t, n) => switch t { - | PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) - | Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) - | SampleSet(r) => Ok(SampleSetDist.sampleN(r, n)) + | PointSet(r) => PointSetDist.sampleNRendered(n, r) + | Symbolic(r) => SymbolicDist.T.sampleN(n, r) + | SampleSet(r) => SampleSetDist.sampleN(r, n) } + let toSampleSetDist = (t: t, n) => - sampleN(t, n)->E.R.bind(SampleSetDist.make)->GenericDist_Types.Error.resultStringToResultError + SampleSetDist.make(sampleN(t, n))->GenericDist_Types.Error.resultStringToResultError let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f)) @@ -72,7 +74,6 @@ let toPointSet = ( pointSetDistLength: xyPointLength, kernelWidth: None, }, - (), )->GenericDist_Types.Error.resultStringToResultError } } @@ -162,14 +163,12 @@ module AlgebraicCombination = { t1: t, t2: t, ) => { - let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation) - E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R.bind(((a, b)) => { - SampleSetDist.map2( - ~fn=arithmeticOperation, - ~t1=a, - ~t2=b, - )->GenericDist_Types.Error.resultStringToResultError + let fn = Operation.Algebraic.toFn(arithmeticOperation) + E.R.merge(toSampleSet(t1), toSampleSet(t2)) + ->E.R.bind(((t1, t2)) => { + SampleSetDist.map2(~fn, ~t1, ~t2)->GenericDist_Types.Error.resultStringToResultError }) + ->E.R2.fmap(r => GenericDist_Types.SampleSet(r)) } //I'm (Ozzie) really just guessing here, very little idea what's best @@ -200,15 +199,7 @@ module AlgebraicCombination = { | Some(Error(e)) => Error(Other(e)) | None => switch chooseConvolutionOrMonteCarlo(t1, t2) { - | #CalculateWithMonteCarlo => { - let sampleSetDist: result = runMonteCarlo( - toSampleSetFn, - arithmeticOperation, - t1, - t2, - ) - sampleSetDist->E.R2.fmap(r => GenericDist_Types.SampleSet(r)) - } + | #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) | #CalculateWithConvolution => runConvolution( toPointSetFn, @@ -274,7 +265,7 @@ let mixture = ( ~pointwiseAddFn: pointwiseAddFn, ) => { if E.A.length(values) == 0 { - Error(GenericDist_Types.Other("mixture must have at least 1 element")) + Error(GenericDist_Types.Other("Mixture error: mixture must have at least 1 element")) } else { let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum let properlyWeightedValues = diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi index b65489e3..4565ec14 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi @@ -5,7 +5,8 @@ type toSampleSetFn = t => result type scaleMultiplyFn = (t, float) => result type pointwiseAddFn = (t, t) => result -let sampleN: (t, int) => result, error> +let sampleN: (t, int) => array + let toSampleSetDist: (t, int) => Belt.Result.t let fromFloat: float => t diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res index ccf1b775..7a63332f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res @@ -1,4 +1,5 @@ module T: { + @genType type t let make: array => result let get: t => array @@ -18,7 +19,7 @@ include T let length = (t: t) => get(t) |> E.A.length // TODO: Refactor to get error in the toPointSetDist function, instead of adding at very end. -let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()): result< +let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs): result< PointSetTypes.pointSetDist, string, > => diff --git a/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res b/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res index 3c14343a..17477f8f 100644 --- a/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res +++ b/packages/squiggle-lang/src/rescript/OldInterpreter/ASTTypes.res @@ -225,7 +225,7 @@ module SamplingDistribution = { let pointSetDist = sampleSetDist -> E.R.bind(r => - SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())); + SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r)); pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r))) }) }