Added logScaleWithThreshold(eps) and completed renaming `logScore` to `klDivergence`

Value: [1e-5 to 1e-3]
This commit is contained in:
Quinn Dougherty 2022-05-04 13:02:58 -04:00
parent 236be470d5
commit c95c56cfb8
11 changed files with 60 additions and 24 deletions

View File

@ -145,7 +145,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
} }
| ToDist(Normalize) => dist->GenericDist.normalize->Dist | ToDist(Normalize) => dist->GenericDist.normalize->Dist
| ToScore(KLDivergence(t2)) => | ToScore(KLDivergence(t2)) =>
GenericDist.logScore(dist, t2, ~toPointSetFn) GenericDist.klDivergence(dist, t2, ~toPointSetFn)
->E.R2.fmap(r => Float(r)) ->E.R2.fmap(r => Float(r))
->OutputLocal.fromResult ->OutputLocal.fromResult
| ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool | ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool
@ -163,6 +163,15 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ()) ->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())
->E.R2.fmap(r => Dist(PointSet(r))) ->E.R2.fmap(r => Dist(PointSet(r)))
->OutputLocal.fromResult ->OutputLocal.fromResult
| ToDist(Scale(#LogarithmWithThreshold(eps), f)) =>
dist
->GenericDist.pointwiseCombinationFloat(
~toPointSetFn,
~algebraicCombination=#LogarithmWithThreshold(eps),
~f,
)
->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult
| ToDist(Scale(#Logarithm, f)) => | ToDist(Scale(#Logarithm, f)) =>
dist dist
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~algebraicCombination=#Logarithm, ~f) ->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~algebraicCombination=#Logarithm, ~f)

View File

@ -72,6 +72,7 @@ module DistributionOperation = {
type toScaleFn = [ type toScaleFn = [
| #Power | #Power
| #Logarithm | #Logarithm
| #LogarithmWithThreshold(float)
] ]
type toDist = type toDist =
@ -126,6 +127,8 @@ module DistributionOperation = {
| ToDist(Inspect) => `inspect` | ToDist(Inspect) => `inspect`
| ToDist(Scale(#Power, r)) => `scalePower(${E.Float.toFixed(r)})` | ToDist(Scale(#Power, r)) => `scalePower(${E.Float.toFixed(r)})`
| ToDist(Scale(#Logarithm, r)) => `scaleLog(${E.Float.toFixed(r)})` | ToDist(Scale(#Logarithm, r)) => `scaleLog(${E.Float.toFixed(r)})`
| ToDist(Scale(#LogarithmWithThreshold(eps), r)) =>
`scaleLogWithThreshold(${E.Float.toFixed(r)}, epsilon=${E.Float.toFixed(eps)})`
| ToString(ToString) => `toString` | ToString(ToString) => `toString`
| ToString(ToSparkline(n)) => `toSparkline(${E.I.toString(n)})` | ToString(ToSparkline(n)) => `toSparkline(${E.I.toString(n)})`
| ToBool(IsNormalized) => `isNormalized` | ToBool(IsNormalized) => `isNormalized`
@ -160,6 +163,10 @@ module Constructors = {
let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1) let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist)
let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist)
// Builds the request for a thresholded logarithmic scaling of `dist`.
// `n` is the float carried by `Scale` (the "scaleLogWithThreshold" dispatcher passes the log base here).
// `eps` is the threshold carried inside the `#LogarithmWithThreshold` tag.
// This only constructs the `FromDist` value; nothing is evaluated here.
let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist(
ToDist(Scale(#LogarithmWithThreshold(eps), n)),
dist,
)
let toString = (dist): t => FromDist(ToString(ToString), dist) let toString = (dist): t => FromDist(ToString(ToString), dist)
let toSparkline = (dist, n): t => FromDist(ToString(ToSparkline(n)), dist) let toSparkline = (dist, n): t => FromDist(ToString(ToSparkline(n)), dist)
let algebraicAdd = (dist1, dist2: genericDist): t => FromDist( let algebraicAdd = (dist1, dist2: genericDist): t => FromDist(

View File

@ -59,10 +59,10 @@ let integralEndY = (t: t): float =>
let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 1.0) < 1e-7 let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 1.0) < 1e-7
let logScore = (t1, t2, ~toPointSetFn: toPointSetFn): result<float, error> => { let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result<float, error> => {
let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2)) let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2))
pointSets |> E.R2.bind(((a, b)) => pointSets |> E.R2.bind(((a, b)) =>
PointSetDist.T.logScore(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x)) PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x))
) )
} }
@ -391,14 +391,12 @@ let pointwiseCombinationFloat = (
~algebraicCombination: Operation.algebraicOperation, ~algebraicCombination: Operation.algebraicOperation,
~f: float, ~f: float,
): result<t, error> => { ): result<t, error> => {
let m = switch algebraicCombination { let executeCombination = arithOp =>
| #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid)
| (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
toPointSetFn(t)->E.R.bind(t => { toPointSetFn(t)->E.R.bind(t => {
//TODO: Move to PointSet codebase //TODO: Move to PointSet codebase
let fn = (secondary, main) => Operation.Scale.toFn(arithmeticOperation, main, secondary) let fn = (secondary, main) => Operation.Scale.toFn(arithOp, main, secondary)
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithmeticOperation) let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithOp)
let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithmeticOperation) let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithOp)
PointSetDist.T.mapYResult( PointSetDist.T.mapYResult(
~integralSumCacheFn=integralSumCacheFn(f), ~integralSumCacheFn=integralSumCacheFn(f),
~integralCacheFn=integralCacheFn(f), ~integralCacheFn=integralCacheFn(f),
@ -406,6 +404,11 @@ let pointwiseCombinationFloat = (
t, t,
)->E.R2.errMap(x => DistributionTypes.OperationError(x)) )->E.R2.errMap(x => DistributionTypes.OperationError(x))
}) })
let m = switch algebraicCombination {
| #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid)
| (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
executeCombination(arithmeticOperation)
| #LogarithmWithThreshold(eps) => executeCombination(#LogarithmWithThreshold(eps))
} }
m->E.R2.fmap(r => DistributionTypes.PointSet(r)) m->E.R2.fmap(r => DistributionTypes.PointSet(r))
} }

View File

@ -23,7 +23,7 @@ let toFloatOperation: (
~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat, ~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat,
) => result<float, error> ) => result<float, error>
let logScore: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error> let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error>
@genType @genType
let toPointSet: ( let toPointSet: (

View File

@ -270,7 +270,7 @@ module T = Dist({
let variance = (t: t): float => let variance = (t: t): float =>
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares) XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
let logScore = (base: t, reference: t) => { let klDivergence = (base: t, reference: t) => {
let referenceIsZero = switch Distributions.Common.isZeroEverywhere( let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
PointSetTypes.Continuous(reference), PointSetTypes.Continuous(reference),
) { ) {

View File

@ -229,7 +229,7 @@ module T = Dist({
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares) XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
} }
let logScore = (base: t, reference: t) => { let klDivergence = (base: t, reference: t) => {
let referenceIsZero = switch Distributions.Common.isZeroEverywhere( let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
PointSetTypes.Discrete(reference), PointSetTypes.Discrete(reference),
) { ) {

View File

@ -33,7 +33,7 @@ module type dist = {
let mean: t => float let mean: t => float
let variance: t => float let variance: t => float
let logScore: (t, t) => result<float, Operation.Error.t> let klDivergence: (t, t) => result<float, Operation.Error.t>
} }
module Dist = (T: dist) => { module Dist = (T: dist) => {
@ -56,7 +56,7 @@ module Dist = (T: dist) => {
let mean = T.mean let mean = T.mean
let variance = T.variance let variance = T.variance
let integralEndY = T.integralEndY let integralEndY = T.integralEndY
let logScore = T.logScore let klDivergence = T.klDivergence
let updateIntegralCache = T.updateIntegralCache let updateIntegralCache = T.updateIntegralCache

View File

@ -301,7 +301,7 @@ module T = Dist({
} }
} }
let logScore = (base: t, reference: t) => { let klDivergence = (base: t, reference: t) => {
let referenceIsZero = switch Distributions.Common.isZeroEverywhere( let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
PointSetTypes.Mixed(reference), PointSetTypes.Mixed(reference),
) { ) {

View File

@ -196,15 +196,15 @@ module T = Dist({
| Continuous(m) => Continuous.T.variance(m) | Continuous(m) => Continuous.T.variance(m)
} }
let logScore = (t1: t, t2: t) => let klDivergence = (t1: t, t2: t) =>
switch (t1, t2) { switch (t1, t2) {
| (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2) | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
| (Discrete(t1), Discrete(t2)) => Discrete.T.logScore(t1, t2) | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
| (Mixed(t1), Mixed(t2)) => Mixed.T.logScore(t1, t2) | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
| _ => { | _ => {
let t1 = toMixed(t1) let t1 = toMixed(t1)
let t2 = toMixed(t2) let t2 = toMixed(t2)
Mixed.T.logScore(t1, t2) Mixed.T.klDivergence(t1, t2)
} }
} }
}) })

View File

@ -222,6 +222,8 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist) | ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist)
| ("scaleLog", [EvDistribution(dist), EvNumber(float)]) => | ("scaleLog", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Scale(#Logarithm, float), dist) Helpers.toDistFn(Scale(#Logarithm, float), dist)
| ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) =>
Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist)
| ("scalePow", [EvDistribution(dist), EvNumber(float)]) => | ("scalePow", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Scale(#Power, float), dist) Helpers.toDistFn(Scale(#Power, float), dist)
| ("scaleExp", [EvDistribution(dist)]) => | ("scaleExp", [EvDistribution(dist)]) =>

View File

@ -8,6 +8,7 @@ type algebraicOperation = [
| #Divide | #Divide
| #Power | #Power
| #Logarithm | #Logarithm
| #LogarithmWithThreshold(float)
] ]
type convolutionOperation = [ type convolutionOperation = [
@ -18,7 +19,7 @@ type convolutionOperation = [
@genType @genType
type pointwiseOperation = [#Add | #Multiply | #Power] type pointwiseOperation = [#Add | #Multiply | #Power]
type scaleOperation = [#Multiply | #Power | #Logarithm | #Divide] type scaleOperation = [#Multiply | #Power | #Logarithm | #LogarithmWithThreshold(float) | #Divide]
type distToFloatOperation = [ type distToFloatOperation = [
| #Pdf(float) | #Pdf(float)
| #Cdf(float) | #Cdf(float)
@ -35,7 +36,7 @@ module Convolution = {
| #Add => Some(#Add) | #Add => Some(#Add)
| #Subtract => Some(#Subtract) | #Subtract => Some(#Subtract)
| #Multiply => Some(#Multiply) | #Multiply => Some(#Multiply)
| #Divide | #Power | #Logarithm => None | #Divide | #Power | #Logarithm | #LogarithmWithThreshold(_) => None
} }
let canDoAlgebraicOperation = (op: algebraicOperation): bool => let canDoAlgebraicOperation = (op: algebraicOperation): bool =>
@ -108,6 +109,12 @@ module Algebraic = {
| #Power => power(a, b) | #Power => power(a, b)
| #Divide => divide(a, b) | #Divide => divide(a, b)
| #Logarithm => logarithm(a, b) | #Logarithm => logarithm(a, b)
| #LogarithmWithThreshold(eps) =>
if a < eps {
Ok(0.0)
} else {
logarithm(a, b)
}
} }
let toString = x => let toString = x =>
@ -118,6 +125,7 @@ module Algebraic = {
| #Power => "**" | #Power => "**"
| #Divide => "/" | #Divide => "/"
| #Logarithm => "log" | #Logarithm => "log"
| #LogarithmWithThreshold(_) => "log"
} }
let format = (a, b, c) => b ++ (" " ++ (toString(a) ++ (" " ++ c))) let format = (a, b, c) => b ++ (" " ++ (toString(a) ++ (" " ++ c)))
@ -162,6 +170,12 @@ module Scale = {
} else { } else {
logarithm(a, b) logarithm(a, b)
} }
| #LogarithmWithThreshold(eps) =>
if a < eps {
Ok(0.0)
} else {
logarithm(a, b)
}
} }
let format = (operation: t, value, scaleBy) => let format = (operation: t, value, scaleBy) =>
@ -170,14 +184,14 @@ module Scale = {
| #Divide => j`verticalDivide($value, $scaleBy) ` | #Divide => j`verticalDivide($value, $scaleBy) `
| #Power => j`verticalPower($value, $scaleBy) ` | #Power => j`verticalPower($value, $scaleBy) `
| #Logarithm => j`verticalLog($value, $scaleBy) ` | #Logarithm => j`verticalLog($value, $scaleBy) `
| #LogarithmWithThreshold(eps) => j`verticalLog($value, $scaleBy, epsilon=$eps) `
} }
let toIntegralSumCacheFn = x => let toIntegralSumCacheFn = x =>
switch x { switch x {
| #Multiply => (a, b) => Some(a *. b) | #Multiply => (a, b) => Some(a *. b)
| #Divide => (a, b) => Some(a /. b) | #Divide => (a, b) => Some(a /. b)
| #Power => (_, _) => None | #Power | #Logarithm | #LogarithmWithThreshold(_) => (_, _) => None
| #Logarithm => (_, _) => None
} }
let toIntegralCacheFn = x => let toIntegralCacheFn = x =>
@ -186,6 +200,7 @@ module Scale = {
| #Divide => (_, _) => None | #Divide => (_, _) => None
| #Power => (_, _) => None | #Power => (_, _) => None
| #Logarithm => (_, _) => None | #Logarithm => (_, _) => None
| #LogarithmWithThreshold(_) => (_, _) => None
} }
} }