Added `scaleLogWithThreshold(eps)` and completed renaming `logScore` to `klDivergence`

Value: [1e-5 to 1e-3]
This commit is contained in:
Quinn Dougherty 2022-05-04 13:02:58 -04:00
parent 236be470d5
commit c95c56cfb8
11 changed files with 60 additions and 24 deletions

View File

@ -145,7 +145,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
}
| ToDist(Normalize) => dist->GenericDist.normalize->Dist
| ToScore(KLDivergence(t2)) =>
GenericDist.logScore(dist, t2, ~toPointSetFn)
GenericDist.klDivergence(dist, t2, ~toPointSetFn)
->E.R2.fmap(r => Float(r))
->OutputLocal.fromResult
| ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool
@ -163,6 +163,15 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())
->E.R2.fmap(r => Dist(PointSet(r)))
->OutputLocal.fromResult
| ToDist(Scale(#LogarithmWithThreshold(eps), f)) =>
dist
->GenericDist.pointwiseCombinationFloat(
~toPointSetFn,
~algebraicCombination=#LogarithmWithThreshold(eps),
~f,
)
->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult
| ToDist(Scale(#Logarithm, f)) =>
dist
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~algebraicCombination=#Logarithm, ~f)

View File

@ -72,6 +72,7 @@ module DistributionOperation = {
// Tags naming the available scaling operations. `#LogarithmWithThreshold` carries a
// float payload (bound as `eps` where the tag is matched).
// NOTE(review): presumably this is the type of the first argument of `Scale(...)` — confirm.
type toScaleFn = [
| #Power
| #Logarithm
| #LogarithmWithThreshold(float)
]
type toDist =
@ -126,6 +127,8 @@ module DistributionOperation = {
| ToDist(Inspect) => `inspect`
| ToDist(Scale(#Power, r)) => `scalePower(${E.Float.toFixed(r)})`
| ToDist(Scale(#Logarithm, r)) => `scaleLog(${E.Float.toFixed(r)})`
| ToDist(Scale(#LogarithmWithThreshold(eps), r)) =>
`scaleLogWithThreshold(${E.Float.toFixed(r)}, epsilon=${E.Float.toFixed(eps)})`
| ToString(ToString) => `toString`
| ToString(ToSparkline(n)) => `toSparkline(${E.I.toString(n)})`
| ToBool(IsNormalized) => `isNormalized`
@ -160,6 +163,10 @@ module Constructors = {
// Builds the call `FromDist(ToScore(KLDivergence(dist2)), dist1)`, i.e. the
// `KLDivergence` score requested on `dist1` with `dist2` as its argument.
// NOTE(review): this constructor is still named `logScore` although the operation tag
// is `KLDivergence` — confirm whether it should be renamed as well.
let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
// Builds the call that applies `Scale(#Power, n)` to `dist`.
let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist)
// Builds the call that applies `Scale(#Logarithm, n)` to `dist`.
let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist)
// Builds the call that applies `Scale(#LogarithmWithThreshold(eps), n)` to `dist`.
// `eps` is the payload of the threshold tag; `n` is passed in the same position as the
// `n` of `scaleLogarithm`.
let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist(
ToDist(Scale(#LogarithmWithThreshold(eps), n)),
dist,
)
// Builds the call that requests the `ToString` rendering of `dist`.
let toString = (dist): t => FromDist(ToString(ToString), dist)
// Builds the call that requests the `ToSparkline(n)` rendering of `dist`.
let toSparkline = (dist, n): t => FromDist(ToString(ToSparkline(n)), dist)
let algebraicAdd = (dist1, dist2: genericDist): t => FromDist(

View File

@ -59,10 +59,10 @@ let integralEndY = (t: t): float =>
// True when `integralEndY(t)` differs from 1.0 by less than 1e-7 in absolute value.
let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 1.0) < 1e-7
// Computes the score of `t1` against `t2` on their point-set forms.
// Returns `Ok(float)` on success; an error coming from the point-set computation is
// wrapped in `DistributionTypes.OperationError`.
let logScore = (t1, t2, ~toPointSetFn: toPointSetFn): result<float, error> => {
let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result<float, error> => {
// Convert both inputs with the caller-supplied `toPointSetFn` and combine the two
// results into a single result holding the pair destructured as `(a, b)` below.
let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2))
pointSets |> E.R2.bind(((a, b)) =>
PointSetDist.T.logScore(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x))
PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x))
)
}
@ -391,14 +391,12 @@ let pointwiseCombinationFloat = (
~algebraicCombination: Operation.algebraicOperation,
~f: float,
): result<t, error> => {
let m = switch algebraicCombination {
| #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid)
| (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
let executeCombination = arithOp =>
toPointSetFn(t)->E.R.bind(t => {
//TODO: Move to PointSet codebase
let fn = (secondary, main) => Operation.Scale.toFn(arithmeticOperation, main, secondary)
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithmeticOperation)
let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithmeticOperation)
let fn = (secondary, main) => Operation.Scale.toFn(arithOp, main, secondary)
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithOp)
let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithOp)
PointSetDist.T.mapYResult(
~integralSumCacheFn=integralSumCacheFn(f),
~integralCacheFn=integralCacheFn(f),
@ -406,6 +404,11 @@ let pointwiseCombinationFloat = (
t,
)->E.R2.errMap(x => DistributionTypes.OperationError(x))
})
let m = switch algebraicCombination {
| #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid)
| (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
executeCombination(arithmeticOperation)
| #LogarithmWithThreshold(eps) => executeCombination(#LogarithmWithThreshold(eps))
}
m->E.R2.fmap(r => DistributionTypes.PointSet(r))
}

View File

@ -23,7 +23,7 @@ let toFloatOperation: (
~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat,
) => result<float, error>
// Takes two distributions plus a labelled `toPointSetFn` and returns either a float
// or an `error`.
let logScore: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error>
let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error>
@genType
let toPointSet: (

View File

@ -270,7 +270,7 @@ module T = Dist({
// Variance of `t`, delegated to `XYShape.Analysis.getVarianceDangerously`, which is
// given this module's `mean` and `Analysis.getMeanOfSquares`.
// NOTE(review): the "Dangerously" suffix suggests unchecked preconditions — confirm
// against the helper's definition.
let variance = (t: t): float =>
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
let logScore = (base: t, reference: t) => {
let klDivergence = (base: t, reference: t) => {
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
PointSetTypes.Continuous(reference),
) {

View File

@ -229,7 +229,7 @@ module T = Dist({
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
}
let logScore = (base: t, reference: t) => {
let klDivergence = (base: t, reference: t) => {
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
PointSetTypes.Discrete(reference),
) {

View File

@ -33,7 +33,7 @@ module type dist = {
let mean: t => float
let variance: t => float
let logScore: (t, t) => result<float, Operation.Error.t>
let klDivergence: (t, t) => result<float, Operation.Error.t>
}
module Dist = (T: dist) => {
@ -56,7 +56,7 @@ module Dist = (T: dist) => {
let mean = T.mean
let variance = T.variance
let integralEndY = T.integralEndY
let logScore = T.logScore
let klDivergence = T.klDivergence
let updateIntegralCache = T.updateIntegralCache

View File

@ -301,7 +301,7 @@ module T = Dist({
}
}
let logScore = (base: t, reference: t) => {
let klDivergence = (base: t, reference: t) => {
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
PointSetTypes.Mixed(reference),
) {

View File

@ -196,15 +196,15 @@ module T = Dist({
| Continuous(m) => Continuous.T.variance(m)
}
// Dispatches on the variants of both arguments. When both are the same kind
// (Continuous, Discrete or Mixed) the computation is delegated to that kind's own
// implementation.
let logScore = (t1: t, t2: t) =>
let klDivergence = (t1: t, t2: t) =>
switch (t1, t2) {
| (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2)
| (Discrete(t1), Discrete(t2)) => Discrete.T.logScore(t1, t2)
| (Mixed(t1), Mixed(t2)) => Mixed.T.logScore(t1, t2)
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
// Any other pairing: convert both sides with `toMixed` and use the Mixed implementation.
| _ => {
let t1 = toMixed(t1)
let t2 = toMixed(t2)
Mixed.T.logScore(t1, t2)
Mixed.T.klDivergence(t1, t2)
}
}
})

View File

@ -222,6 +222,8 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist)
| ("scaleLog", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Scale(#Logarithm, float), dist)
| ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) =>
Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist)
| ("scalePow", [EvDistribution(dist), EvNumber(float)]) =>
Helpers.toDistFn(Scale(#Power, float), dist)
| ("scaleExp", [EvDistribution(dist)]) =>

View File

@ -8,6 +8,7 @@ type algebraicOperation = [
| #Divide
| #Power
| #Logarithm
| #LogarithmWithThreshold(float)
]
type convolutionOperation = [
@ -18,7 +19,7 @@ type convolutionOperation = [
// Tags for the pointwise operations: add, multiply and power.
// The declaration is annotated with `@genType`.
@genType
type pointwiseOperation = [#Add | #Multiply | #Power]
// Tags for the scale operations. The second declaration additionally lists
// `#LogarithmWithThreshold(float)`, a logarithm tag carrying a float payload.
type scaleOperation = [#Multiply | #Power | #Logarithm | #Divide]
type scaleOperation = [#Multiply | #Power | #Logarithm | #LogarithmWithThreshold(float) | #Divide]
type distToFloatOperation = [
| #Pdf(float)
| #Cdf(float)
@ -35,7 +36,7 @@ module Convolution = {
| #Add => Some(#Add)
| #Subtract => Some(#Subtract)
| #Multiply => Some(#Multiply)
| #Divide | #Power | #Logarithm => None
| #Divide | #Power | #Logarithm | #LogarithmWithThreshold(_) => None
}
let canDoAlgebraicOperation = (op: algebraicOperation): bool =>
@ -108,6 +109,12 @@ module Algebraic = {
| #Power => power(a, b)
| #Divide => divide(a, b)
| #Logarithm => logarithm(a, b)
| #LogarithmWithThreshold(eps) =>
if a < eps {
Ok(0.0)
} else {
logarithm(a, b)
}
}
let toString = x =>
@ -118,6 +125,7 @@ module Algebraic = {
| #Power => "**"
| #Divide => "/"
| #Logarithm => "log"
| #LogarithmWithThreshold(_) => "log"
}
// Renders an operation in infix form: left operand, a space, the operation's
// `toString` symbol, a space, right operand.
let format = (operation, lhs, rhs) => lhs ++ " " ++ toString(operation) ++ " " ++ rhs
@ -162,6 +170,12 @@ module Scale = {
} else {
logarithm(a, b)
}
| #LogarithmWithThreshold(eps) =>
if a < eps {
Ok(0.0)
} else {
logarithm(a, b)
}
}
let format = (operation: t, value, scaleBy) =>
@ -170,14 +184,14 @@ module Scale = {
| #Divide => j`verticalDivide($value, $scaleBy) `
| #Power => j`verticalPower($value, $scaleBy) `
| #Logarithm => j`verticalLog($value, $scaleBy) `
| #LogarithmWithThreshold(eps) => j`verticalLog($value, $scaleBy, epsilon=$eps) `
}
// Maps a scale operation tag to a two-argument function that yields an optional float:
// `#Multiply` gives the product of the two arguments, `#Divide` gives the first divided
// by the second, and the power/logarithm tags always give `None`.
// NOTE(review): presumably the result is the updated integral-sum cache and `None` means
// it cannot be derived cheaply — confirm against the call sites.
let toIntegralSumCacheFn = x =>
switch x {
| #Multiply => (a, b) => Some(a *. b)
| #Divide => (a, b) => Some(a /. b)
| #Power => (_, _) => None
| #Logarithm => (_, _) => None
| #Power | #Logarithm | #LogarithmWithThreshold(_) => (_, _) => None
}
let toIntegralCacheFn = x =>
@ -186,6 +200,7 @@ module Scale = {
| #Divide => (_, _) => None
| #Power => (_, _) => None
| #Logarithm => (_, _) => None
| #LogarithmWithThreshold(_) => (_, _) => None
}
}