From c95c56cfb8ad214741fa4b458981511b0f40ca89 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Wed, 4 May 2022 13:02:58 -0400 Subject: [PATCH] Added `logScaleWithThreshold(eps)` and completed renaming to `klDivergence` Value: [1e-5 to 1e-3] --- .../DistributionOperation.res | 11 ++++++++- .../Distributions/DistributionTypes.res | 7 ++++++ .../Distributions/GenericDist/GenericDist.res | 19 ++++++++------- .../GenericDist/GenericDist.resi | 2 +- .../Distributions/PointSetDist/Continuous.res | 2 +- .../Distributions/PointSetDist/Discrete.res | 2 +- .../PointSetDist/Distributions.res | 4 ++-- .../Distributions/PointSetDist/Mixed.res | 2 +- .../PointSetDist/PointSetDist.res | 10 ++++---- .../ReducerInterface_GenericDistribution.res | 2 ++ .../src/rescript/Utility/Operation.res | 23 +++++++++++++++---- 11 files changed, 60 insertions(+), 24 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res index fa6f0f39..6872136e 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res @@ -145,7 +145,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { } | ToDist(Normalize) => dist->GenericDist.normalize->Dist | ToScore(KLDivergence(t2)) => - GenericDist.logScore(dist, t2, ~toPointSetFn) + GenericDist.klDivergence(dist, t2, ~toPointSetFn) ->E.R2.fmap(r => Float(r)) ->OutputLocal.fromResult | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool @@ -163,6 +163,15 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { ->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ()) ->E.R2.fmap(r => Dist(PointSet(r))) ->OutputLocal.fromResult + | ToDist(Scale(#LogarithmWithThreshold(eps), f)) => + dist + ->GenericDist.pointwiseCombinationFloat( + ~toPointSetFn, + ~algebraicCombination=#LogarithmWithThreshold(eps), + ~f, + ) + ->E.R2.fmap(r => Dist(r)) + ->OutputLocal.fromResult | ToDist(Scale(#Logarithm, f)) => dist ->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~algebraicCombination=#Logarithm, ~f) diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index 9e004288..03210270 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -72,6 +72,7 @@ module DistributionOperation = { type toScaleFn = [ | #Power | #Logarithm + | #LogarithmWithThreshold(float) ] type toDist = @@ -126,6 +127,8 @@ module DistributionOperation = { | ToDist(Inspect) => `inspect` | ToDist(Scale(#Power, r)) => `scalePower(${E.Float.toFixed(r)})` | ToDist(Scale(#Logarithm, r)) => `scaleLog(${E.Float.toFixed(r)})` + | ToDist(Scale(#LogarithmWithThreshold(eps), r)) => + `scaleLogWithThreshold(${E.Float.toFixed(r)}, epsilon=${E.Float.toFixed(eps)})` | ToString(ToString) => `toString` | ToString(ToSparkline(n)) => `toSparkline(${E.I.toString(n)})` | ToBool(IsNormalized) => `isNormalized` @@ -160,6 +163,10 @@ module Constructors = { let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) + let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist( + ToDist(Scale(#LogarithmWithThreshold(eps), n)), + dist, + ) let toString = (dist): t => FromDist(ToString(ToString), dist) let toSparkline = (dist, n): t => FromDist(ToString(ToSparkline(n)), dist) let algebraicAdd = (dist1, dist2: genericDist): t => FromDist( diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index b30a46b4..2085d72c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -59,10 +59,10 @@ let integralEndY = (t: t): float => let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 1.0) < 1e-7 -let logScore = (t1, t2, ~toPointSetFn: toPointSetFn): result => { +let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result => { let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2)) pointSets |> E.R2.bind(((a, b)) => - PointSetDist.T.logScore(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x)) + PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x)) ) } @@ -391,14 +391,12 @@ let pointwiseCombinationFloat = ( ~algebraicCombination: Operation.algebraicOperation, ~f: float, ): result => { - let m = switch algebraicCombination { - | #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid) - | (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation => + let executeCombination = arithOp => toPointSetFn(t)->E.R.bind(t => { //TODO: Move to PointSet codebase - let fn = (secondary, main) => Operation.Scale.toFn(arithmeticOperation, main, secondary) - let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithmeticOperation) - let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithmeticOperation) + let fn = (secondary, main) => Operation.Scale.toFn(arithOp, main, secondary) + let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithOp) + let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithOp) PointSetDist.T.mapYResult( ~integralSumCacheFn=integralSumCacheFn(f), ~integralCacheFn=integralCacheFn(f), @@ -406,6 +404,11 @@ let pointwiseCombinationFloat = ( t, )->E.R2.errMap(x => DistributionTypes.OperationError(x)) }) + let m = switch algebraicCombination { + | #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid) + | (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation => + executeCombination(arithmeticOperation) + | #LogarithmWithThreshold(eps) => executeCombination(#LogarithmWithThreshold(eps)) } m->E.R2.fmap(r => DistributionTypes.PointSet(r)) } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi index cd1b72a9..03bc5fe8 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.resi @@ -23,7 +23,7 @@ let toFloatOperation: ( ~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat, ) => result -let logScore: (t, t, ~toPointSetFn: toPointSetFn) => result +let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result @genType let toPointSet: ( diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index 2afd1008..93f6c1c9 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -270,7 +270,7 @@ module T = Dist({ let variance = (t: t): float => XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) - let logScore = (base: t, reference: t) => { + let klDivergence = (base: t, reference: t) => { let referenceIsZero = switch Distributions.Common.isZeroEverywhere( PointSetTypes.Continuous(reference), ) { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res index 0f2e2eb3..8ec9410f 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Discrete.res @@ -229,7 +229,7 @@ module T = Dist({ XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares) } - let logScore = (base: t, reference: t) => { + let klDivergence = (base: t, reference: t) => { let referenceIsZero = switch Distributions.Common.isZeroEverywhere( PointSetTypes.Discrete(reference), ) { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res index 0ad3b24a..5c9ee1aa 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Distributions.res @@ -33,7 +33,7 @@ module type dist = { let mean: t => float let variance: t => float - let logScore: (t, t) => result + let klDivergence: (t, t) => result } module Dist = (T: dist) => { @@ -56,7 +56,7 @@ module Dist = (T: dist) => { let mean = T.mean let variance = T.variance let integralEndY = T.integralEndY - let logScore = T.logScore + let klDivergence = T.klDivergence let updateIntegralCache = T.updateIntegralCache diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index a3b41401..66f353f6 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -301,7 +301,7 @@ module T = Dist({ } } - let logScore = (base: t, reference: t) => { + let klDivergence = (base: t, reference: t) => { let referenceIsZero = switch Distributions.Common.isZeroEverywhere( PointSetTypes.Mixed(reference), ) { diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index d7baad9f..7b905316 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -196,15 +196,15 @@ module T = Dist({ | Continuous(m) => Continuous.T.variance(m) } - let logScore = (t1: t, t2: t) => + let klDivergence = (t1: t, t2: t) => switch (t1, t2) { - | (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2) - | (Discrete(t1), Discrete(t2)) => Discrete.T.logScore(t1, t2) - | (Mixed(t1), Mixed(t2)) => Mixed.T.logScore(t1, t2) + | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) + | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) + | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2) | _ => { let t1 = toMixed(t1) let t2 = toMixed(t2) - Mixed.T.logScore(t1, t2) + Mixed.T.klDivergence(t1, t2) } } }) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 1d99022a..1175767d 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -222,6 +222,8 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment) | ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist) | ("scaleLog", [EvDistribution(dist), EvNumber(float)]) => Helpers.toDistFn(Scale(#Logarithm, float), dist) + | ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) => + Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist) | ("scalePow", [EvDistribution(dist), EvNumber(float)]) => Helpers.toDistFn(Scale(#Power, float), dist) | ("scaleExp", [EvDistribution(dist)]) => diff --git a/packages/squiggle-lang/src/rescript/Utility/Operation.res b/packages/squiggle-lang/src/rescript/Utility/Operation.res index bd4e799f..f124ffcd 100644 --- a/packages/squiggle-lang/src/rescript/Utility/Operation.res +++ b/packages/squiggle-lang/src/rescript/Utility/Operation.res @@ -8,6 +8,7 @@ type algebraicOperation = [ | #Divide | #Power | #Logarithm + | #LogarithmWithThreshold(float) ] type convolutionOperation = [ @@ -18,7 +19,7 @@ type convolutionOperation = [ @genType type pointwiseOperation = [#Add | #Multiply | #Power] -type scaleOperation = [#Multiply | #Power | #Logarithm | #Divide] +type scaleOperation = [#Multiply | #Power | #Logarithm | #LogarithmWithThreshold(float) | #Divide] type distToFloatOperation = [ | #Pdf(float) | #Cdf(float) @@ -35,7 +36,7 @@ module Convolution = { | #Add => Some(#Add) | #Subtract => Some(#Subtract) | #Multiply => Some(#Multiply) - | #Divide | #Power | #Logarithm => None + | #Divide | #Power | #Logarithm | #LogarithmWithThreshold(_) => None } let canDoAlgebraicOperation = (op: algebraicOperation): bool => @@ -108,6 +109,12 @@ module Algebraic = { | #Power => power(a, b) | #Divide => divide(a, b) | #Logarithm => logarithm(a, b) + | #LogarithmWithThreshold(eps) => + if a < eps { + Ok(0.0) + } else { + logarithm(a, b) + } } let toString = x => @@ -118,6 +125,7 @@ module Algebraic = { | #Power => "**" | #Divide => "/" | #Logarithm => "log" + | #LogarithmWithThreshold(_) => "log" } let format = (a, b, c) => b ++ (" " ++ (toString(a) ++ (" " ++ c))) @@ -162,6 +170,12 @@ module Scale = { } else { logarithm(a, b) } + | #LogarithmWithThreshold(eps) => + if a < eps { + Ok(0.0) + } else { + logarithm(a, b) + } } let format = (operation: t, value, scaleBy) => @@ -170,14 +184,14 @@ module Scale = { | #Divide => j`verticalDivide($value, $scaleBy) ` | #Power => j`verticalPower($value, $scaleBy) ` | #Logarithm => j`verticalLog($value, $scaleBy) ` + | #LogarithmWithThreshold(eps) => j`verticalLog($value, $scaleBy, epsilon=$eps) ` } let toIntegralSumCacheFn = x => switch x { | #Multiply => (a, b) => Some(a *. b) | #Divide => (a, b) => Some(a /. b) - | #Power => (_, _) => None - | #Logarithm => (_, _) => None + | #Power | #Logarithm | #LogarithmWithThreshold(_) => (_, _) => None } let toIntegralCacheFn = x => @@ -186,6 +200,7 @@ module Scale = { | #Divide => (_, _) => None | #Power => (_, _) => None | #Logarithm => (_, _) => None + | #LogarithmWithThreshold(_) => (_, _) => None } }