Completed renaming to KLDivergence

Value: [1e-8 to 1e-4]
This commit is contained in:
Quinn Dougherty 2022-05-04 12:21:30 -04:00
parent 683439c7e5
commit 3fcc82442d
4 changed files with 7 additions and 7 deletions

View File

@ -24,7 +24,7 @@ describe("Scale logarithm", () => {
let meanAnalytical = -.Js.Math.log2(high -. low) /. 2.0 *. (high ** 2.0 -. low ** 2.0) // -. Js.Math.log2(high -. low)
switch meanResult {
| Ok(meanValue) => meanValue->expect->toBeCloseTo(meanAnalytical)
| Error(err) => err->expect->toBe(DistributionTypes.OperationError(NegativeInfinityError))
| Error(err) => err->expect->toEqual(DistributionTypes.OperationError(NegativeInfinityError))
}
})
})

View File

@ -144,7 +144,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
Dist(dist)
}
| ToDist(Normalize) => dist->GenericDist.normalize->Dist
| ToScore(LogScore(t2)) =>
| ToScore(KLDivergence(t2)) =>
GenericDist.logScore(dist, t2, ~toPointSetFn)
->E.R2.fmap(r => Float(r))
->OutputLocal.fromResult

View File

@ -90,7 +90,7 @@ module DistributionOperation = {
| ToString
| ToSparkline(int)
type toScore = LogScore(genericDist)
type toScore = KLDivergence(genericDist)
type fromDist =
| ToFloat(toFloat)
@ -118,7 +118,7 @@ module DistributionOperation = {
| ToFloat(#Pdf(r)) => `pdf(${E.Float.toFixed(r)})`
| ToFloat(#Sample) => `sample`
| ToFloat(#IntegralSum) => `integralSum`
| ToScore(LogScore(_)) => `logScore`
| ToScore(KLDivergence(_)) => `klDivergence`
| ToDist(Normalize) => `normalize`
| ToDist(ToPointSet) => `toPointSet`
| ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})`
@ -157,7 +157,7 @@ module Constructors = {
let fromSamples = (xs): t => FromSamples(xs)
let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist)
let inspect = (dist): t => FromDist(ToDist(Inspect), dist)
let logScore = (dist1, dist2): t => FromDist(ToScore(LogScore(dist2)), dist1)
let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist)
let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist)
let toString = (dist): t => FromDist(ToString(ToString), dist)

View File

@ -210,8 +210,8 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
a,
)->Some
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist)
| ("logScore", [EvDistribution(a), EvDistribution(b)]) =>
Some(runGenericOperation(FromDist(ToScore(LogScore(b)), a)))
| ("klDivergence", [EvDistribution(a), EvDistribution(b)]) =>
Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a)))
| ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist)
| ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist)
| ("scaleLog", [EvDistribution(dist)]) =>