diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 1467239e..06c54040 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -271,12 +271,22 @@ module T = Dist({ XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) let logScore = (base: t, reference: t) => { - E.R2.bind( - combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference), - combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference), - ) - |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) - |> E.R.fmap(integralEndY) + let referenceIsZero = switch Distributions.Common.isZeroEverywhere( + PointSetTypes.Continuous(reference), + ) { + | Continuous(b) => b + | _ => false + } + if referenceIsZero { + Ok(0.0) + } else { + E.R2.bind( + combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference), + combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference), + ) + |> E.R.fmap(shapeMap(XYShape.T.filterYValues(Js.Float.isFinite))) + |> E.R.fmap(integralEndY) + } } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index a6c04e6b..a74283dd 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -230,8 +230,19 @@ module T = Dist({ } let logScore = (base: t, reference: t) => { - combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R2.bind( - integralEndYResult, - ) + let referenceIsZero = switch Distributions.Common.isZeroEverywhere( + PointSetTypes.Discrete(reference), + ) { + | Discrete(b) => b + | _ => false + } + if referenceIsZero { + Ok(0.0) + } else { + E.R2.bind( + combinePointwise(~fn=PointSetDist_Scoring.LogScoring.multiply, reference), + combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference), + ) |> E.R2.bind(integralEndYResult) + } } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index 15da318d..0ad3b24a 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -96,4 +96,18 @@ module Common = { None | (Some(s1), Some(s2)) => combineFn(s1, s2) } + + let isZeroEverywhere = (d: PointSetTypes.pointSetDist) => { + let isZero = (x: float): bool => x == 0.0 + PointSetTypes.ShapeMonad.fmap( + d, + ( + mixed => + E.A.all(isZero, mixed.continuous.xyShape.ys) && + E.A.all(isZero, mixed.discrete.xyShape.ys), + disc => E.A.all(isZero, disc.xyShape.ys), + cont => E.A.all(isZero, cont.xyShape.ys), + ), + ) + } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 3d36d421..0e528496 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -302,9 +302,20 @@ module T = Dist({ } let logScore = (base: t, reference: t) => { - combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R.fmap( - integralEndY, - ) + let referenceIsZero = switch Distributions.Common.isZeroEverywhere( + PointSetTypes.Mixed(reference), + ) { + | Mixed(b) => b + | _ => false + } + if referenceIsZero { + Ok(0.0) + } else { + E.R2.bind( + combinePointwise(PointSetDist_Scoring.LogScoring.multiply, reference), + combinePointwise(PointSetDist_Scoring.LogScoring.logScore, base, reference), + ) |> E.R.fmap(integralEndY) + } } }) diff --git a/packages/squiggle-lang/src/rescript/Utility/E.res b/packages/squiggle-lang/src/rescript/Utility/E.res index 472c32f7..076fc89f 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E.res +++ b/packages/squiggle-lang/src/rescript/Utility/E.res @@ -606,6 +606,9 @@ module A = { let filter = Js.Array.filter let joinWith = Js.Array.joinWith + let all = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) == length(xs) + let any = (p: 'a => bool, xs: array<'a>): bool => length(filter(p, xs)) > 0 + module O = { let concatSomes = (optionals: array>): array<'a> => optionals