Added logScaleWithThreshold(eps)
and completed renaming to
`klDivergence` Value: [1e-5 to 1e-3]
This commit is contained in:
parent
236be470d5
commit
c95c56cfb8
|
@ -145,7 +145,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
}
|
}
|
||||||
| ToDist(Normalize) => dist->GenericDist.normalize->Dist
|
| ToDist(Normalize) => dist->GenericDist.normalize->Dist
|
||||||
| ToScore(KLDivergence(t2)) =>
|
| ToScore(KLDivergence(t2)) =>
|
||||||
GenericDist.logScore(dist, t2, ~toPointSetFn)
|
GenericDist.klDivergence(dist, t2, ~toPointSetFn)
|
||||||
->E.R2.fmap(r => Float(r))
|
->E.R2.fmap(r => Float(r))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool
|
| ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool
|
||||||
|
@ -163,6 +163,15 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())
|
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())
|
||||||
->E.R2.fmap(r => Dist(PointSet(r)))
|
->E.R2.fmap(r => Dist(PointSet(r)))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
|
| ToDist(Scale(#LogarithmWithThreshold(eps), f)) =>
|
||||||
|
dist
|
||||||
|
->GenericDist.pointwiseCombinationFloat(
|
||||||
|
~toPointSetFn,
|
||||||
|
~algebraicCombination=#LogarithmWithThreshold(eps),
|
||||||
|
~f,
|
||||||
|
)
|
||||||
|
->E.R2.fmap(r => Dist(r))
|
||||||
|
->OutputLocal.fromResult
|
||||||
| ToDist(Scale(#Logarithm, f)) =>
|
| ToDist(Scale(#Logarithm, f)) =>
|
||||||
dist
|
dist
|
||||||
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~algebraicCombination=#Logarithm, ~f)
|
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~algebraicCombination=#Logarithm, ~f)
|
||||||
|
|
|
@ -72,6 +72,7 @@ module DistributionOperation = {
|
||||||
type toScaleFn = [
|
type toScaleFn = [
|
||||||
| #Power
|
| #Power
|
||||||
| #Logarithm
|
| #Logarithm
|
||||||
|
| #LogarithmWithThreshold(float)
|
||||||
]
|
]
|
||||||
|
|
||||||
type toDist =
|
type toDist =
|
||||||
|
@ -126,6 +127,8 @@ module DistributionOperation = {
|
||||||
| ToDist(Inspect) => `inspect`
|
| ToDist(Inspect) => `inspect`
|
||||||
| ToDist(Scale(#Power, r)) => `scalePower(${E.Float.toFixed(r)})`
|
| ToDist(Scale(#Power, r)) => `scalePower(${E.Float.toFixed(r)})`
|
||||||
| ToDist(Scale(#Logarithm, r)) => `scaleLog(${E.Float.toFixed(r)})`
|
| ToDist(Scale(#Logarithm, r)) => `scaleLog(${E.Float.toFixed(r)})`
|
||||||
|
| ToDist(Scale(#LogarithmWithThreshold(eps), r)) =>
|
||||||
|
`scaleLogWithThreshold(${E.Float.toFixed(r)}, epsilon=${E.Float.toFixed(eps)})`
|
||||||
| ToString(ToString) => `toString`
|
| ToString(ToString) => `toString`
|
||||||
| ToString(ToSparkline(n)) => `toSparkline(${E.I.toString(n)})`
|
| ToString(ToSparkline(n)) => `toSparkline(${E.I.toString(n)})`
|
||||||
| ToBool(IsNormalized) => `isNormalized`
|
| ToBool(IsNormalized) => `isNormalized`
|
||||||
|
@ -160,6 +163,10 @@ module Constructors = {
|
||||||
let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
|
let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
|
||||||
let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist)
|
let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist)
|
||||||
let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist)
|
let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist)
|
||||||
|
let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist(
|
||||||
|
ToDist(Scale(#LogarithmWithThreshold(eps), n)),
|
||||||
|
dist,
|
||||||
|
)
|
||||||
let toString = (dist): t => FromDist(ToString(ToString), dist)
|
let toString = (dist): t => FromDist(ToString(ToString), dist)
|
||||||
let toSparkline = (dist, n): t => FromDist(ToString(ToSparkline(n)), dist)
|
let toSparkline = (dist, n): t => FromDist(ToString(ToSparkline(n)), dist)
|
||||||
let algebraicAdd = (dist1, dist2: genericDist): t => FromDist(
|
let algebraicAdd = (dist1, dist2: genericDist): t => FromDist(
|
||||||
|
|
|
@ -59,10 +59,10 @@ let integralEndY = (t: t): float =>
|
||||||
|
|
||||||
let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 1.0) < 1e-7
|
let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 1.0) < 1e-7
|
||||||
|
|
||||||
let logScore = (t1, t2, ~toPointSetFn: toPointSetFn): result<float, error> => {
|
let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result<float, error> => {
|
||||||
let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2))
|
let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2))
|
||||||
pointSets |> E.R2.bind(((a, b)) =>
|
pointSets |> E.R2.bind(((a, b)) =>
|
||||||
PointSetDist.T.logScore(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -391,14 +391,12 @@ let pointwiseCombinationFloat = (
|
||||||
~algebraicCombination: Operation.algebraicOperation,
|
~algebraicCombination: Operation.algebraicOperation,
|
||||||
~f: float,
|
~f: float,
|
||||||
): result<t, error> => {
|
): result<t, error> => {
|
||||||
let m = switch algebraicCombination {
|
let executeCombination = arithOp =>
|
||||||
| #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid)
|
|
||||||
| (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
|
|
||||||
toPointSetFn(t)->E.R.bind(t => {
|
toPointSetFn(t)->E.R.bind(t => {
|
||||||
//TODO: Move to PointSet codebase
|
//TODO: Move to PointSet codebase
|
||||||
let fn = (secondary, main) => Operation.Scale.toFn(arithmeticOperation, main, secondary)
|
let fn = (secondary, main) => Operation.Scale.toFn(arithOp, main, secondary)
|
||||||
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithmeticOperation)
|
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithOp)
|
||||||
let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithmeticOperation)
|
let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithOp)
|
||||||
PointSetDist.T.mapYResult(
|
PointSetDist.T.mapYResult(
|
||||||
~integralSumCacheFn=integralSumCacheFn(f),
|
~integralSumCacheFn=integralSumCacheFn(f),
|
||||||
~integralCacheFn=integralCacheFn(f),
|
~integralCacheFn=integralCacheFn(f),
|
||||||
|
@ -406,6 +404,11 @@ let pointwiseCombinationFloat = (
|
||||||
t,
|
t,
|
||||||
)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
||||||
})
|
})
|
||||||
|
let m = switch algebraicCombination {
|
||||||
|
| #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid)
|
||||||
|
| (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
|
||||||
|
executeCombination(arithmeticOperation)
|
||||||
|
| #LogarithmWithThreshold(eps) => executeCombination(#LogarithmWithThreshold(eps))
|
||||||
}
|
}
|
||||||
m->E.R2.fmap(r => DistributionTypes.PointSet(r))
|
m->E.R2.fmap(r => DistributionTypes.PointSet(r))
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ let toFloatOperation: (
|
||||||
~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat,
|
~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat,
|
||||||
) => result<float, error>
|
) => result<float, error>
|
||||||
|
|
||||||
let logScore: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error>
|
let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error>
|
||||||
|
|
||||||
@genType
|
@genType
|
||||||
let toPointSet: (
|
let toPointSet: (
|
||||||
|
|
|
@ -270,7 +270,7 @@ module T = Dist({
|
||||||
let variance = (t: t): float =>
|
let variance = (t: t): float =>
|
||||||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||||
|
|
||||||
let logScore = (base: t, reference: t) => {
|
let klDivergence = (base: t, reference: t) => {
|
||||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||||
PointSetTypes.Continuous(reference),
|
PointSetTypes.Continuous(reference),
|
||||||
) {
|
) {
|
||||||
|
|
|
@ -229,7 +229,7 @@ module T = Dist({
|
||||||
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
|
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
|
||||||
}
|
}
|
||||||
|
|
||||||
let logScore = (base: t, reference: t) => {
|
let klDivergence = (base: t, reference: t) => {
|
||||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||||
PointSetTypes.Discrete(reference),
|
PointSetTypes.Discrete(reference),
|
||||||
) {
|
) {
|
||||||
|
|
|
@ -33,7 +33,7 @@ module type dist = {
|
||||||
|
|
||||||
let mean: t => float
|
let mean: t => float
|
||||||
let variance: t => float
|
let variance: t => float
|
||||||
let logScore: (t, t) => result<float, Operation.Error.t>
|
let klDivergence: (t, t) => result<float, Operation.Error.t>
|
||||||
}
|
}
|
||||||
|
|
||||||
module Dist = (T: dist) => {
|
module Dist = (T: dist) => {
|
||||||
|
@ -56,7 +56,7 @@ module Dist = (T: dist) => {
|
||||||
let mean = T.mean
|
let mean = T.mean
|
||||||
let variance = T.variance
|
let variance = T.variance
|
||||||
let integralEndY = T.integralEndY
|
let integralEndY = T.integralEndY
|
||||||
let logScore = T.logScore
|
let klDivergence = T.klDivergence
|
||||||
|
|
||||||
let updateIntegralCache = T.updateIntegralCache
|
let updateIntegralCache = T.updateIntegralCache
|
||||||
|
|
||||||
|
|
|
@ -301,7 +301,7 @@ module T = Dist({
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let logScore = (base: t, reference: t) => {
|
let klDivergence = (base: t, reference: t) => {
|
||||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||||
PointSetTypes.Mixed(reference),
|
PointSetTypes.Mixed(reference),
|
||||||
) {
|
) {
|
||||||
|
|
|
@ -196,15 +196,15 @@ module T = Dist({
|
||||||
| Continuous(m) => Continuous.T.variance(m)
|
| Continuous(m) => Continuous.T.variance(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
let logScore = (t1: t, t2: t) =>
|
let klDivergence = (t1: t, t2: t) =>
|
||||||
switch (t1, t2) {
|
switch (t1, t2) {
|
||||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2)
|
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
||||||
| (Discrete(t1), Discrete(t2)) => Discrete.T.logScore(t1, t2)
|
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
||||||
| (Mixed(t1), Mixed(t2)) => Mixed.T.logScore(t1, t2)
|
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
|
||||||
| _ => {
|
| _ => {
|
||||||
let t1 = toMixed(t1)
|
let t1 = toMixed(t1)
|
||||||
let t2 = toMixed(t2)
|
let t2 = toMixed(t2)
|
||||||
Mixed.T.logScore(t1, t2)
|
Mixed.T.klDivergence(t1, t2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -222,6 +222,8 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
|
||||||
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist)
|
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist)
|
||||||
| ("scaleLog", [EvDistribution(dist), EvNumber(float)]) =>
|
| ("scaleLog", [EvDistribution(dist), EvNumber(float)]) =>
|
||||||
Helpers.toDistFn(Scale(#Logarithm, float), dist)
|
Helpers.toDistFn(Scale(#Logarithm, float), dist)
|
||||||
|
| ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) =>
|
||||||
|
Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist)
|
||||||
| ("scalePow", [EvDistribution(dist), EvNumber(float)]) =>
|
| ("scalePow", [EvDistribution(dist), EvNumber(float)]) =>
|
||||||
Helpers.toDistFn(Scale(#Power, float), dist)
|
Helpers.toDistFn(Scale(#Power, float), dist)
|
||||||
| ("scaleExp", [EvDistribution(dist)]) =>
|
| ("scaleExp", [EvDistribution(dist)]) =>
|
||||||
|
|
|
@ -8,6 +8,7 @@ type algebraicOperation = [
|
||||||
| #Divide
|
| #Divide
|
||||||
| #Power
|
| #Power
|
||||||
| #Logarithm
|
| #Logarithm
|
||||||
|
| #LogarithmWithThreshold(float)
|
||||||
]
|
]
|
||||||
|
|
||||||
type convolutionOperation = [
|
type convolutionOperation = [
|
||||||
|
@ -18,7 +19,7 @@ type convolutionOperation = [
|
||||||
|
|
||||||
@genType
|
@genType
|
||||||
type pointwiseOperation = [#Add | #Multiply | #Power]
|
type pointwiseOperation = [#Add | #Multiply | #Power]
|
||||||
type scaleOperation = [#Multiply | #Power | #Logarithm | #Divide]
|
type scaleOperation = [#Multiply | #Power | #Logarithm | #LogarithmWithThreshold(float) | #Divide]
|
||||||
type distToFloatOperation = [
|
type distToFloatOperation = [
|
||||||
| #Pdf(float)
|
| #Pdf(float)
|
||||||
| #Cdf(float)
|
| #Cdf(float)
|
||||||
|
@ -35,7 +36,7 @@ module Convolution = {
|
||||||
| #Add => Some(#Add)
|
| #Add => Some(#Add)
|
||||||
| #Subtract => Some(#Subtract)
|
| #Subtract => Some(#Subtract)
|
||||||
| #Multiply => Some(#Multiply)
|
| #Multiply => Some(#Multiply)
|
||||||
| #Divide | #Power | #Logarithm => None
|
| #Divide | #Power | #Logarithm | #LogarithmWithThreshold(_) => None
|
||||||
}
|
}
|
||||||
|
|
||||||
let canDoAlgebraicOperation = (op: algebraicOperation): bool =>
|
let canDoAlgebraicOperation = (op: algebraicOperation): bool =>
|
||||||
|
@ -108,6 +109,12 @@ module Algebraic = {
|
||||||
| #Power => power(a, b)
|
| #Power => power(a, b)
|
||||||
| #Divide => divide(a, b)
|
| #Divide => divide(a, b)
|
||||||
| #Logarithm => logarithm(a, b)
|
| #Logarithm => logarithm(a, b)
|
||||||
|
| #LogarithmWithThreshold(eps) =>
|
||||||
|
if a < eps {
|
||||||
|
Ok(0.0)
|
||||||
|
} else {
|
||||||
|
logarithm(a, b)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let toString = x =>
|
let toString = x =>
|
||||||
|
@ -118,6 +125,7 @@ module Algebraic = {
|
||||||
| #Power => "**"
|
| #Power => "**"
|
||||||
| #Divide => "/"
|
| #Divide => "/"
|
||||||
| #Logarithm => "log"
|
| #Logarithm => "log"
|
||||||
|
| #LogarithmWithThreshold(_) => "log"
|
||||||
}
|
}
|
||||||
|
|
||||||
let format = (a, b, c) => b ++ (" " ++ (toString(a) ++ (" " ++ c)))
|
let format = (a, b, c) => b ++ (" " ++ (toString(a) ++ (" " ++ c)))
|
||||||
|
@ -162,6 +170,12 @@ module Scale = {
|
||||||
} else {
|
} else {
|
||||||
logarithm(a, b)
|
logarithm(a, b)
|
||||||
}
|
}
|
||||||
|
| #LogarithmWithThreshold(eps) =>
|
||||||
|
if a < eps {
|
||||||
|
Ok(0.0)
|
||||||
|
} else {
|
||||||
|
logarithm(a, b)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let format = (operation: t, value, scaleBy) =>
|
let format = (operation: t, value, scaleBy) =>
|
||||||
|
@ -170,14 +184,14 @@ module Scale = {
|
||||||
| #Divide => j`verticalDivide($value, $scaleBy) `
|
| #Divide => j`verticalDivide($value, $scaleBy) `
|
||||||
| #Power => j`verticalPower($value, $scaleBy) `
|
| #Power => j`verticalPower($value, $scaleBy) `
|
||||||
| #Logarithm => j`verticalLog($value, $scaleBy) `
|
| #Logarithm => j`verticalLog($value, $scaleBy) `
|
||||||
|
| #LogarithmWithThreshold(eps) => j`verticalLog($value, $scaleBy, epsilon=$eps) `
|
||||||
}
|
}
|
||||||
|
|
||||||
let toIntegralSumCacheFn = x =>
|
let toIntegralSumCacheFn = x =>
|
||||||
switch x {
|
switch x {
|
||||||
| #Multiply => (a, b) => Some(a *. b)
|
| #Multiply => (a, b) => Some(a *. b)
|
||||||
| #Divide => (a, b) => Some(a /. b)
|
| #Divide => (a, b) => Some(a /. b)
|
||||||
| #Power => (_, _) => None
|
| #Power | #Logarithm | #LogarithmWithThreshold(_) => (_, _) => None
|
||||||
| #Logarithm => (_, _) => None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let toIntegralCacheFn = x =>
|
let toIntegralCacheFn = x =>
|
||||||
|
@ -186,6 +200,7 @@ module Scale = {
|
||||||
| #Divide => (_, _) => None
|
| #Divide => (_, _) => None
|
||||||
| #Power => (_, _) => None
|
| #Power => (_, _) => None
|
||||||
| #Logarithm => (_, _) => None
|
| #Logarithm => (_, _) => None
|
||||||
|
| #LogarithmWithThreshold(_) => (_, _) => None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user