diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index beb27980..4d2177f5 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -83,7 +83,8 @@ let toPointSet = ( pointSetDistLength: xyPointLength, kernelWidth: None, }, - )->GenericDist_Types.Error.resultStringToResultError + ) + ->GenericDist_Types.Error.resultStringToResultError } } @@ -161,10 +162,12 @@ module AlgebraicCombination = { arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation, t1: t, t2: t, - ) => + ) => { + let normalize = PointSetDist.T.normalize E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) => - PointSetDist.combineAlgebraically(arithmeticOperation, a, b) + PointSetDist.combineAlgebraically(arithmeticOperation, normalize(a), normalize(b))->normalize ) + } let runMonteCarlo = ( toSampleSet: toSampleSetFn, @@ -196,6 +199,50 @@ module AlgebraicCombination = { ? #CalculateWithMonteCarlo : #CalculateWithConvolution + let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option => { + let firstOperandIsGreaterThanZero = + toFloatOperation(t1, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r => + r > 0. + ) + let secondOperandIsGreaterThanZero = + toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r => + r > 0. + ) + let secondOperandHasMassAt1 = + toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Pdf(1.0)) |> E.R.fmap(r => + r >= 1e-10 + ) + let items = E.A.R.firstErrorOrOpen([ + firstOperandIsGreaterThanZero, + secondOperandIsGreaterThanZero, + secondOperandHasMassAt1, + ]) + Js.log2("PMASS", toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Pdf(1.0))) + Js.log4("HIHI", items, t1, t2) + switch items { + | Error(r) => Some(r) + | Ok([true, _, _]) => Some(Other("First input of logarithm must be fully greater than 0")) + | Ok([false, true, _]) => Some(Other("Second input of logarithm must be fully greater than 0")) + | Ok([false, false, true]) => + Some(Other("Second input of logarithm cannot have probability mass at 1.0")) + | Ok([false, false, false]) => None + | Ok(_) => Some(Unreachable) + } + } + + let getInvalidOperationError = ( + t1: t, + t2: t, + ~toPointSetFn: toPointSetFn, + ~arithmeticOperation, + ): option => { + if arithmeticOperation == #Logarithm { + getLogarithmInputError(t1, t2, ~toPointSetFn) + } else { + None + } + } + let run = ( t1: t, ~toPointSetFn: toPointSetFn, @@ -207,15 +254,19 @@ module AlgebraicCombination = { | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist)) | Some(Error(e)) => Error(Other(e)) | None => - switch chooseConvolutionOrMonteCarlo(t1, t2) { - | #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) - | #CalculateWithConvolution => - runConvolution( - toPointSetFn, - arithmeticOperation, - t1, - t2, - )->E.R2.fmap(r => DistributionTypes.PointSet(r)) + switch getInvalidOperationError(t1, t2, ~toPointSetFn, ~arithmeticOperation) { + | Some(e) => Error(e) + | None => + switch chooseConvolutionOrMonteCarlo(t1, t2) { + | #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) + | #CalculateWithConvolution => + runConvolution( + toPointSetFn, + arithmeticOperation, + t1, + t2, + )->E.R2.fmap(r => DistributionTypes.PointSet(PointSetDist.T.normalize(r))) + } } } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index d0668a57..d9d7eec8 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -93,7 +93,20 @@ module T = Dist({ t, ) - let normalize = fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize)) + let integralEndY = mapToAll(( + Mixed.T.Integral.sum, + Discrete.T.Integral.sum, + Continuous.T.Integral.sum, + )) + + let isNormalized = t => integralEndY(t) == 1.0 + + let normalize = (t: t): t => + if isNormalized(t) { + t + } else { + t |> fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize)) + } let updateIntegralCache = (integralCache, t: t): t => fmap( @@ -124,11 +137,6 @@ module T = Dist({ Discrete.T.Integral.get, Continuous.T.Integral.get, )) - let integralEndY = mapToAll(( - Mixed.T.Integral.sum, - Discrete.T.Integral.sum, - Continuous.T.Integral.sum, - )) let integralXtoY = f => mapToAll((Mixed.T.Integral.xToY(f), Discrete.T.Integral.xToY(f), Continuous.T.Integral.xToY(f))) let integralYtoX = f => diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res index fcd4055d..8907d85e 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist.res @@ -64,5 +64,10 @@ let sampleN = (t: t, n) => { //TODO: Figure out what to do if distributions are different lengths. ``zip`` is kind of inelegant for this. let map2 = (~fn: (float, float) => float, ~t1: t, ~t2: t) => { let samples = Belt.Array.zip(get(t1), get(t2))->E.A2.fmap(((a, b)) => fn(a, b)) - make(samples) + let has_invalid_results = Belt.Array.some(samples, a => Js.Float.isNaN(a)) + if has_invalid_results { + Error("Distribution combination produced invalid results") + } else { + make(samples) + } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res index 59b4fa46..9138aecc 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SampleSetDist/SampleSetDist_ToPointSet.res @@ -133,9 +133,12 @@ let toPointSetDist = ( ~discrete=Some(discrete), ) + //The latter doesn't always produce a normalized result, so we need to normalize it. + let normalized = pointSetDist->E.O2.fmap(PointSetDist.T.normalize) + let samplesParse: Internals.Types.outputs = { continuousParseParams: pdf |> E.O.fmap(snd), - pointSetDist: pointSetDist, + pointSetDist: normalized, } samplesParse