Added logScaleWithThreshold(eps)
and completed renaming to
`klDivergence` Value: [1e-5 to 1e-3]
This commit is contained in:
parent
236be470d5
commit
c95c56cfb8
|
@ -145,7 +145,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
|||
}
|
||||
| ToDist(Normalize) => dist->GenericDist.normalize->Dist
|
||||
| ToScore(KLDivergence(t2)) =>
|
||||
GenericDist.logScore(dist, t2, ~toPointSetFn)
|
||||
GenericDist.klDivergence(dist, t2, ~toPointSetFn)
|
||||
->E.R2.fmap(r => Float(r))
|
||||
->OutputLocal.fromResult
|
||||
| ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool
|
||||
|
@ -163,6 +163,15 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
|||
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())
|
||||
->E.R2.fmap(r => Dist(PointSet(r)))
|
||||
->OutputLocal.fromResult
|
||||
| ToDist(Scale(#LogarithmWithThreshold(eps), f)) =>
|
||||
dist
|
||||
->GenericDist.pointwiseCombinationFloat(
|
||||
~toPointSetFn,
|
||||
~algebraicCombination=#LogarithmWithThreshold(eps),
|
||||
~f,
|
||||
)
|
||||
->E.R2.fmap(r => Dist(r))
|
||||
->OutputLocal.fromResult
|
||||
| ToDist(Scale(#Logarithm, f)) =>
|
||||
dist
|
||||
->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~algebraicCombination=#Logarithm, ~f)
|
||||
|
|
|
@ -72,6 +72,7 @@ module DistributionOperation = {
|
|||
type toScaleFn = [
|
||||
| #Power
|
||||
| #Logarithm
|
||||
| #LogarithmWithThreshold(float)
|
||||
]
|
||||
|
||||
type toDist =
|
||||
|
@ -126,6 +127,8 @@ module DistributionOperation = {
|
|||
| ToDist(Inspect) => `inspect`
|
||||
| ToDist(Scale(#Power, r)) => `scalePower(${E.Float.toFixed(r)})`
|
||||
| ToDist(Scale(#Logarithm, r)) => `scaleLog(${E.Float.toFixed(r)})`
|
||||
| ToDist(Scale(#LogarithmWithThreshold(eps), r)) =>
|
||||
`scaleLogWithThreshold(${E.Float.toFixed(r)}, epsilon=${E.Float.toFixed(eps)})`
|
||||
| ToString(ToString) => `toString`
|
||||
| ToString(ToSparkline(n)) => `toSparkline(${E.I.toString(n)})`
|
||||
| ToBool(IsNormalized) => `isNormalized`
|
||||
|
@ -160,6 +163,10 @@ module Constructors = {
|
|||
let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
|
||||
let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist)
|
||||
let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist)
|
||||
let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist(
|
||||
ToDist(Scale(#LogarithmWithThreshold(eps), n)),
|
||||
dist,
|
||||
)
|
||||
let toString = (dist): t => FromDist(ToString(ToString), dist)
|
||||
let toSparkline = (dist, n): t => FromDist(ToString(ToSparkline(n)), dist)
|
||||
let algebraicAdd = (dist1, dist2: genericDist): t => FromDist(
|
||||
|
|
|
@ -59,10 +59,10 @@ let integralEndY = (t: t): float =>
|
|||
|
||||
let isNormalized = (t: t): bool => Js.Math.abs_float(integralEndY(t) -. 1.0) < 1e-7
|
||||
|
||||
let logScore = (t1, t2, ~toPointSetFn: toPointSetFn): result<float, error> => {
|
||||
let klDivergence = (t1, t2, ~toPointSetFn: toPointSetFn): result<float, error> => {
|
||||
let pointSets = E.R.merge(toPointSetFn(t1), toPointSetFn(t2))
|
||||
pointSets |> E.R2.bind(((a, b)) =>
|
||||
PointSetDist.T.logScore(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
||||
PointSetDist.T.klDivergence(a, b)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -391,14 +391,12 @@ let pointwiseCombinationFloat = (
|
|||
~algebraicCombination: Operation.algebraicOperation,
|
||||
~f: float,
|
||||
): result<t, error> => {
|
||||
let m = switch algebraicCombination {
|
||||
| #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid)
|
||||
| (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
|
||||
let executeCombination = arithOp =>
|
||||
toPointSetFn(t)->E.R.bind(t => {
|
||||
//TODO: Move to PointSet codebase
|
||||
let fn = (secondary, main) => Operation.Scale.toFn(arithmeticOperation, main, secondary)
|
||||
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithmeticOperation)
|
||||
let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithmeticOperation)
|
||||
let fn = (secondary, main) => Operation.Scale.toFn(arithOp, main, secondary)
|
||||
let integralSumCacheFn = Operation.Scale.toIntegralSumCacheFn(arithOp)
|
||||
let integralCacheFn = Operation.Scale.toIntegralCacheFn(arithOp)
|
||||
PointSetDist.T.mapYResult(
|
||||
~integralSumCacheFn=integralSumCacheFn(f),
|
||||
~integralCacheFn=integralCacheFn(f),
|
||||
|
@ -406,6 +404,11 @@ let pointwiseCombinationFloat = (
|
|||
t,
|
||||
)->E.R2.errMap(x => DistributionTypes.OperationError(x))
|
||||
})
|
||||
let m = switch algebraicCombination {
|
||||
| #Add | #Subtract => Error(DistributionTypes.DistributionVerticalShiftIsInvalid)
|
||||
| (#Multiply | #Divide | #Power | #Logarithm) as arithmeticOperation =>
|
||||
executeCombination(arithmeticOperation)
|
||||
| #LogarithmWithThreshold(eps) => executeCombination(#LogarithmWithThreshold(eps))
|
||||
}
|
||||
m->E.R2.fmap(r => DistributionTypes.PointSet(r))
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ let toFloatOperation: (
|
|||
~distToFloatOperation: DistributionTypes.DistributionOperation.toFloat,
|
||||
) => result<float, error>
|
||||
|
||||
let logScore: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error>
|
||||
let klDivergence: (t, t, ~toPointSetFn: toPointSetFn) => result<float, error>
|
||||
|
||||
@genType
|
||||
let toPointSet: (
|
||||
|
|
|
@ -270,7 +270,7 @@ module T = Dist({
|
|||
let variance = (t: t): float =>
|
||||
XYShape.Analysis.getVarianceDangerously(t, mean, Analysis.getMeanOfSquares)
|
||||
|
||||
let logScore = (base: t, reference: t) => {
|
||||
let klDivergence = (base: t, reference: t) => {
|
||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||
PointSetTypes.Continuous(reference),
|
||||
) {
|
||||
|
|
|
@ -229,7 +229,7 @@ module T = Dist({
|
|||
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
|
||||
}
|
||||
|
||||
let logScore = (base: t, reference: t) => {
|
||||
let klDivergence = (base: t, reference: t) => {
|
||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||
PointSetTypes.Discrete(reference),
|
||||
) {
|
||||
|
|
|
@ -33,7 +33,7 @@ module type dist = {
|
|||
|
||||
let mean: t => float
|
||||
let variance: t => float
|
||||
let logScore: (t, t) => result<float, Operation.Error.t>
|
||||
let klDivergence: (t, t) => result<float, Operation.Error.t>
|
||||
}
|
||||
|
||||
module Dist = (T: dist) => {
|
||||
|
@ -56,7 +56,7 @@ module Dist = (T: dist) => {
|
|||
let mean = T.mean
|
||||
let variance = T.variance
|
||||
let integralEndY = T.integralEndY
|
||||
let logScore = T.logScore
|
||||
let klDivergence = T.klDivergence
|
||||
|
||||
let updateIntegralCache = T.updateIntegralCache
|
||||
|
||||
|
|
|
@ -301,7 +301,7 @@ module T = Dist({
|
|||
}
|
||||
}
|
||||
|
||||
let logScore = (base: t, reference: t) => {
|
||||
let klDivergence = (base: t, reference: t) => {
|
||||
let referenceIsZero = switch Distributions.Common.isZeroEverywhere(
|
||||
PointSetTypes.Mixed(reference),
|
||||
) {
|
||||
|
|
|
@ -196,15 +196,15 @@ module T = Dist({
|
|||
| Continuous(m) => Continuous.T.variance(m)
|
||||
}
|
||||
|
||||
let logScore = (t1: t, t2: t) =>
|
||||
let klDivergence = (t1: t, t2: t) =>
|
||||
switch (t1, t2) {
|
||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.logScore(t1, t2)
|
||||
| (Discrete(t1), Discrete(t2)) => Discrete.T.logScore(t1, t2)
|
||||
| (Mixed(t1), Mixed(t2)) => Mixed.T.logScore(t1, t2)
|
||||
| (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
|
||||
| (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
|
||||
| (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
|
||||
| _ => {
|
||||
let t1 = toMixed(t1)
|
||||
let t2 = toMixed(t2)
|
||||
Mixed.T.logScore(t1, t2)
|
||||
Mixed.T.klDivergence(t1, t2)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -222,6 +222,8 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
|
|||
| ("scaleLog10", [EvDistribution(dist)]) => Helpers.toDistFn(Scale(#Logarithm, 10.0), dist)
|
||||
| ("scaleLog", [EvDistribution(dist), EvNumber(float)]) =>
|
||||
Helpers.toDistFn(Scale(#Logarithm, float), dist)
|
||||
| ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) =>
|
||||
Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist)
|
||||
| ("scalePow", [EvDistribution(dist), EvNumber(float)]) =>
|
||||
Helpers.toDistFn(Scale(#Power, float), dist)
|
||||
| ("scaleExp", [EvDistribution(dist)]) =>
|
||||
|
|
|
@ -8,6 +8,7 @@ type algebraicOperation = [
|
|||
| #Divide
|
||||
| #Power
|
||||
| #Logarithm
|
||||
| #LogarithmWithThreshold(float)
|
||||
]
|
||||
|
||||
type convolutionOperation = [
|
||||
|
@ -18,7 +19,7 @@ type convolutionOperation = [
|
|||
|
||||
@genType
|
||||
type pointwiseOperation = [#Add | #Multiply | #Power]
|
||||
type scaleOperation = [#Multiply | #Power | #Logarithm | #Divide]
|
||||
type scaleOperation = [#Multiply | #Power | #Logarithm | #LogarithmWithThreshold(float) | #Divide]
|
||||
type distToFloatOperation = [
|
||||
| #Pdf(float)
|
||||
| #Cdf(float)
|
||||
|
@ -35,7 +36,7 @@ module Convolution = {
|
|||
| #Add => Some(#Add)
|
||||
| #Subtract => Some(#Subtract)
|
||||
| #Multiply => Some(#Multiply)
|
||||
| #Divide | #Power | #Logarithm => None
|
||||
| #Divide | #Power | #Logarithm | #LogarithmWithThreshold(_) => None
|
||||
}
|
||||
|
||||
let canDoAlgebraicOperation = (op: algebraicOperation): bool =>
|
||||
|
@ -108,6 +109,12 @@ module Algebraic = {
|
|||
| #Power => power(a, b)
|
||||
| #Divide => divide(a, b)
|
||||
| #Logarithm => logarithm(a, b)
|
||||
| #LogarithmWithThreshold(eps) =>
|
||||
if a < eps {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
logarithm(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
let toString = x =>
|
||||
|
@ -118,6 +125,7 @@ module Algebraic = {
|
|||
| #Power => "**"
|
||||
| #Divide => "/"
|
||||
| #Logarithm => "log"
|
||||
| #LogarithmWithThreshold(_) => "log"
|
||||
}
|
||||
|
||||
let format = (a, b, c) => b ++ (" " ++ (toString(a) ++ (" " ++ c)))
|
||||
|
@ -162,6 +170,12 @@ module Scale = {
|
|||
} else {
|
||||
logarithm(a, b)
|
||||
}
|
||||
| #LogarithmWithThreshold(eps) =>
|
||||
if a < eps {
|
||||
Ok(0.0)
|
||||
} else {
|
||||
logarithm(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
let format = (operation: t, value, scaleBy) =>
|
||||
|
@ -170,14 +184,14 @@ module Scale = {
|
|||
| #Divide => j`verticalDivide($value, $scaleBy) `
|
||||
| #Power => j`verticalPower($value, $scaleBy) `
|
||||
| #Logarithm => j`verticalLog($value, $scaleBy) `
|
||||
| #LogarithmWithThreshold(eps) => j`verticalLog($value, $scaleBy, epsilon=$eps) `
|
||||
}
|
||||
|
||||
let toIntegralSumCacheFn = x =>
|
||||
switch x {
|
||||
| #Multiply => (a, b) => Some(a *. b)
|
||||
| #Divide => (a, b) => Some(a /. b)
|
||||
| #Power => (_, _) => None
|
||||
| #Logarithm => (_, _) => None
|
||||
| #Power | #Logarithm | #LogarithmWithThreshold(_) => (_, _) => None
|
||||
}
|
||||
|
||||
let toIntegralCacheFn = x =>
|
||||
|
@ -186,6 +200,7 @@ module Scale = {
|
|||
| #Divide => (_, _) => None
|
||||
| #Power => (_, _) => None
|
||||
| #Logarithm => (_, _) => None
|
||||
| #LogarithmWithThreshold(_) => (_, _) => None
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user