Completed renaming to KLDivergence

Value: [1e-8 to 1e-4]
This commit is contained in:
Quinn Dougherty 2022-05-04 12:21:30 -04:00
parent 683439c7e5
commit 3fcc82442d
4 changed files with 7 additions and 7 deletions

View File

@ -24,7 +24,7 @@ describe("Scale logarithm", () => {
let meanAnalytical = -.Js.Math.log2(high -. low) /. 2.0 *. (high ** 2.0 -. low ** 2.0) // -. Js.Math.log2(high -. low) let meanAnalytical = -.Js.Math.log2(high -. low) /. 2.0 *. (high ** 2.0 -. low ** 2.0) // -. Js.Math.log2(high -. low)
switch meanResult { switch meanResult {
| Ok(meanValue) => meanValue->expect->toBeCloseTo(meanAnalytical) | Ok(meanValue) => meanValue->expect->toBeCloseTo(meanAnalytical)
| Error(err) => err->expect->toBe(DistributionTypes.OperationError(NegativeInfinityError)) | Error(err) => err->expect->toEqual(DistributionTypes.OperationError(NegativeInfinityError))
} }
}) })
}) })

View File

@ -144,7 +144,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
Dist(dist) Dist(dist)
} }
| ToDist(Normalize) => dist->GenericDist.normalize->Dist | ToDist(Normalize) => dist->GenericDist.normalize->Dist
| ToScore(LogScore(t2)) => | ToScore(KLDivergence(t2)) =>
GenericDist.logScore(dist, t2, ~toPointSetFn) GenericDist.logScore(dist, t2, ~toPointSetFn)
->E.R2.fmap(r => Float(r)) ->E.R2.fmap(r => Float(r))
->OutputLocal.fromResult ->OutputLocal.fromResult

View File

@ -90,7 +90,7 @@ module DistributionOperation = {
| ToString | ToString
| ToSparkline(int) | ToSparkline(int)
type toScore = LogScore(genericDist) type toScore = KLDivergence(genericDist)
type fromDist = type fromDist =
| ToFloat(toFloat) | ToFloat(toFloat)
@ -118,7 +118,7 @@ module DistributionOperation = {
| ToFloat(#Pdf(r)) => `pdf(${E.Float.toFixed(r)})` | ToFloat(#Pdf(r)) => `pdf(${E.Float.toFixed(r)})`
| ToFloat(#Sample) => `sample` | ToFloat(#Sample) => `sample`
| ToFloat(#IntegralSum) => `integralSum` | ToFloat(#IntegralSum) => `integralSum`
| ToScore(LogScore(_)) => `logScore` | ToScore(KLDivergence(_)) => `klDivergence`
| ToDist(Normalize) => `normalize` | ToDist(Normalize) => `normalize`
| ToDist(ToPointSet) => `toPointSet` | ToDist(ToPointSet) => `toPointSet`
| ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})` | ToDist(ToSampleSet(r)) => `toSampleSet(${E.I.toString(r)})`
@ -157,7 +157,7 @@ module Constructors = {
let fromSamples = (xs): t => FromSamples(xs) let fromSamples = (xs): t => FromSamples(xs)
let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist) let truncate = (dist, left, right): t => FromDist(ToDist(Truncate(left, right)), dist)
let inspect = (dist): t => FromDist(ToDist(Inspect), dist) let inspect = (dist): t => FromDist(ToDist(Inspect), dist)
let logScore = (dist1, dist2): t => FromDist(ToScore(LogScore(dist2)), dist1) let logScore = (dist1, dist2): t => FromDist(ToScore(KLDivergence(dist2)), dist1)
let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist)
let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist)
let toString = (dist): t => FromDist(ToString(ToString), dist) let toString = (dist): t => FromDist(ToString(ToString), dist)

View File

@ -210,8 +210,8 @@ let dispatchToGenericOutput = (call: ExpressionValue.functionCall, _environment)
a, a,
)->Some )->Some
| ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist) | ("normalize", [EvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist)
| ("logScore", [EvDistribution(a), EvDistribution(b)]) => | ("klDivergence", [EvDistribution(a), EvDistribution(b)]) =>
Some(runGenericOperation(FromDist(ToScore(LogScore(b)), a))) Some(runGenericOperation(FromDist(ToScore(KLDivergence(b)), a)))
| ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist) | ("isNormalized", [EvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist)
| ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist) | ("toPointSet", [EvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist)
| ("scaleLog", [EvDistribution(dist)]) => | ("scaleLog", [EvDistribution(dist)]) =>