diff --git a/packages/squiggle-lang/__tests__/TestHelpers.res b/packages/squiggle-lang/__tests__/TestHelpers.res index 4502c68a..c9ed718e 100644 --- a/packages/squiggle-lang/__tests__/TestHelpers.res +++ b/packages/squiggle-lang/__tests__/TestHelpers.res @@ -61,18 +61,13 @@ let lognormalMake = SymbolicDist.Lognormal.make let triangularMake = SymbolicDist.Triangular.make let floatMake = SymbolicDist.Float.make -let normalMakeR = (mean, stdev) => - E.R.fmap(s => DistributionTypes.Symbolic(s), SymbolicDist.Normal.make(mean, stdev)) -let betaMakeR = (alpha, beta) => - E.R.fmap(s => DistributionTypes.Symbolic(s), SymbolicDist.Beta.make(alpha, beta)) -let exponentialMakeR = rate => - E.R.fmap(s => DistributionTypes.Symbolic(s), SymbolicDist.Exponential.make(rate)) -let uniformMakeR = (low, high) => - E.R.fmap(s => DistributionTypes.Symbolic(s), SymbolicDist.Uniform.make(low, high)) -let cauchyMakeR = (local, rate) => - E.R.fmap(s => DistributionTypes.Symbolic(s), SymbolicDist.Cauchy.make(local, rate)) -let lognormalMakeR = (mu, sigma) => - E.R.fmap(s => DistributionTypes.Symbolic(s), SymbolicDist.Lognormal.make(mu, sigma)) +let fmapGenDist = symbdistres => E.R.fmap(s => DistributionTypes.Symbolic(s), symbdistres) +let normalMakeR = (mean, stdev) => fmapGenDist(SymbolicDist.Normal.make(mean, stdev)) +let betaMakeR = (alpha, beta) => fmapGenDist(SymbolicDist.Beta.make(alpha, beta)) +let exponentialMakeR = rate => fmapGenDist(SymbolicDist.Exponential.make(rate)) +let uniformMakeR = (low, high) => fmapGenDist(SymbolicDist.Uniform.make(low, high)) +let cauchyMakeR = (local, rate) => fmapGenDist(SymbolicDist.Cauchy.make(local, rate)) +let lognormalMakeR = (mu, sigma) => fmapGenDist(SymbolicDist.Lognormal.make(mu, sigma)) let triangularMakeR = (low, mode, high) => - E.R.fmap(s => DistributionTypes.Symbolic(s), SymbolicDist.Triangular.make(low, mode, high)) + fmapGenDist(SymbolicDist.Triangular.make(low, mode, high)) // let floatMakeR = x =>E.R.fmap(s => DistributionTypes.Symbolic(s), SymbolicDist.Float.make(x)) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 4eef1f91..a6c04e6b 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -233,6 +233,5 @@ module T = Dist({ combinePointwise(~fn=PointSetDist_Scoring.LogScoring.logScore, base, reference) |> E.R2.bind( integralEndYResult, ) - // |> (r => Ok(r)) } }) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res index d025d835..1524243d 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist_Scoring.res @@ -1,8 +1,14 @@ module LogScoring = { let logFn = Js.Math.log let subtraction = (a, b) => Ok(a -. b) - let logScore = (a: float, b: float): result => Ok( - logFn(Js.Math.abs_float(a /. b)), - ) + let logScore = (a: float, b: float): result => + if a == 0.0 { + Error(Operation.Error.NegativeInfinityError) + } else if b == 0.0 { + Error(Operation.Error.DivideByZeroError) + } else { + let quot = a /. b + quot < 0.0 ? Error(OperationError.ComplexNumberError) : Ok(logFn(quot)) + } let multiply = (a: float, b: float): result => Ok(a *. b) } diff --git a/packages/squiggle-lang/src/rescript/Utility/Operation.res b/packages/squiggle-lang/src/rescript/Utility/Operation.res index 8f67c340..a14fdcdc 100644 --- a/packages/squiggle-lang/src/rescript/Utility/Operation.res +++ b/packages/squiggle-lang/src/rescript/Utility/Operation.res @@ -53,6 +53,7 @@ type operationError = | DivisionByZeroError | ComplexNumberError | InfinityError + | NegativeInfinityError @genType module Error = { @@ -63,7 +64,8 @@ module Error = { switch err { | DivisionByZeroError => "Cannot divide by zero" | ComplexNumberError => "Operation returned complex result" - | InfinityError => "Operation returned + or - infinity" + | InfinityError => "Operation returned positive infinity" + | NegativeInfinityError => "Operation returned negative infinity" } } @@ -89,7 +91,7 @@ let logarithm = (a: float, b: float): result => } else if a > 0.0 && b > 0.0 { Ok(log(a) /. log(b)) } else if a == 0.0 { - Error(InfinityError) + Error(NegativeInfinityError) } else { Error(ComplexNumberError) }