Merge pull request #512 from quantified-uncertainty/kldivergence-mixed

`klDivergence` on mixed distributions

Commit 937458cd05
@@ -11,6 +11,7 @@ let triangularDist: DistributionTypes.genericDist = Symbolic(
)
let exponentialDist: DistributionTypes.genericDist = Symbolic(#Exponential({rate: 2.0}))
let uniformDist: DistributionTypes.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
let uniformDist2: DistributionTypes.genericDist = Symbolic(#Uniform({low: 8.0, high: 11.0}))
let floatDist: DistributionTypes.genericDist = Symbolic(#Float(1e1))

exception KlFailed
---
@@ -3,9 +3,15 @@ open Expect
open TestHelpers
open GenericDist_Fixtures

// KL(Uniform(low, high) || Normal(mean, stdev)) = integral from low to high of (1 / (high - low)) * log((1 / (high - low)) / normal(mean, stdev)(x)) dx
let klNormalUniform = (mean, stdev, low, high): float =>
  -.Js.Math.log((high -. low) /. Js.Math.sqrt(2.0 *. MagicNumbers.Math.pi *. stdev ** 2.0)) +.
  1.0 /.
  (2.0 *. stdev ** 2.0) *.
  (mean ** 2.0 -. (high +. low) *. mean +. (low ** 2.0 +. high *. low +. high ** 2.0) /. 3.0)
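
The analytic expression follows from a short calculation (natural logs; $\mu$ = `mean`, $\sigma$ = `stdev`, $l$ = `low`, $h$ = `high`):

$$
D_{\mathrm{KL}}\!\left(\mathrm{Uniform}(l,h)\,\middle\|\,\mathcal{N}(\mu,\sigma^2)\right)
= \int_l^h \frac{1}{h-l}\,\ln\frac{1/(h-l)}{\mathcal{N}(x;\mu,\sigma^2)}\,dx
= -\ln\frac{h-l}{\sqrt{2\pi\sigma^2}} + \frac{1}{2\sigma^2}\left(\mu^2 - (h+l)\mu + \frac{l^2 + lh + h^2}{3}\right),
$$

using $\frac{1}{h-l}\int_l^h (x-\mu)^2\,dx = \mu^2 - (h+l)\mu + \frac{l^2 + lh + h^2}{3}$.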

describe("klDivergence: continuous -> continuous -> float", () => {
  let klDivergence = DistributionOperation.Constructors.klDivergence(~env)
  exception KlFailed

  let testUniform = (lowAnswer, highAnswer, lowPrediction, highPrediction) => {
    test("of two uniforms is equal to the analytic expression", () => {
@@ -59,6 +65,20 @@ describe("klDivergence: continuous -> continuous -> float", () => {
      }
    }
  })

|   test("of a normal and a uniform is equal to the formula", () => { | ||||
|     let prediction = normalDist10 | ||||
|     let answer = uniformDist | ||||
|     let kl = klDivergence(prediction, answer) | ||||
|     let analyticalKl = klNormalUniform(10.0, 2.0, 9.0, 10.0) | ||||
|     switch kl { | ||||
|     | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=1) | ||||
|     | Error(err) => { | ||||
|         Js.Console.log(DistributionTypes.Error.toString(err)) | ||||
|         raise(KlFailed) | ||||
|       } | ||||
|     } | ||||
|   }) | ||||
})
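
Plugging the test's parameters into `klNormalUniform` ($\mu = 10$, $\sigma = 2$, $l = 9$, $h = 10$):

$$
-\ln\frac{10 - 9}{\sqrt{2\pi \cdot 2^2}} + \frac{1}{8}\left(100 - 190 + \frac{271}{3}\right) \approx 1.612 + 0.042 \approx 1.65,
$$

so the computed divergence has to land within about $0.05$ of $1.65$ nats to satisfy `~digits=1`.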

describe("klDivergence: discrete -> discrete -> float", () => {
@@ -96,6 +116,64 @@ describe("klDivergence: discrete -> discrete -> float", () => {
  })
})

| describe("klDivergence: mixed -> mixed -> float", () => { | ||||
|   let klDivergence = DistributionOperation.Constructors.klDivergence(~env) | ||||
|   let mixture' = a => DistributionTypes.DistributionOperation.Mixture(a) | ||||
|   let mixture = a => { | ||||
|     let dist' = a->mixture'->run | ||||
|     switch dist' { | ||||
|     | Dist(dist) => dist | ||||
|     | _ => raise(MixtureFailed) | ||||
|     } | ||||
|   } | ||||
|   let a = [(point1, 1.0), (uniformDist, 1.0)]->mixture | ||||
|   let b = [(point1, 1.0), (floatDist, 1.0), (normalDist10, 1.0)]->mixture | ||||
|   let c = [(point1, 1.0), (point2, 1.0), (point3, 1.0), (uniformDist, 1.0)]->mixture | ||||
|   let d = | ||||
|     [(point1, 1.0), (point2, 1.0), (point3, 1.0), (floatDist, 1.0), (uniformDist2, 1.0)]->mixture | ||||

  test("finite klDivergence produces correct answer", () => {
    let prediction = b
    let answer = a
    let kl = klDivergence(prediction, answer)
    // high = 10; low = 9; mean = 10; stdev = 2
    let analyticalKlContinuousPart = klNormalUniform(10.0, 2.0, 9.0, 10.0) /. 2.0
    let analyticalKlDiscretePart = 1.0 /. 2.0 *. Js.Math.log(2.0 /. 1.0)
    switch kl {
    | Ok(kl') =>
      kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1)
    | Error(err) =>
      Js.Console.log(DistributionTypes.Error.toString(err))
      raise(KlFailed)
    }
  })
|   test("returns infinity when infinite", () => { | ||||
|     let prediction = a | ||||
|     let answer = b | ||||
|     let kl = klDivergence(prediction, answer) | ||||
|     switch kl { | ||||
|     | Ok(kl') => kl'->expect->toEqual(infinity) | ||||
|     | Error(err) => | ||||
|       Js.Console.log(DistributionTypes.Error.toString(err)) | ||||
|       raise(KlFailed) | ||||
|     } | ||||
|   }) | ||||
|   test("finite klDivergence produces correct answer", () => { | ||||
|     let prediction = d | ||||
|     let answer = c | ||||
|     let kl = klDivergence(prediction, answer) | ||||
|     let analyticalKlContinuousPart = Js.Math.log((11.0 -. 8.0) /. (10.0 -. 9.0)) /. 4.0 // 4 = length of c' array | ||||
|     let analyticalKlDiscretePart = 3.0 /. 4.0 *. Js.Math.log(4.0 /. 3.0) | ||||
|     switch kl { | ||||
|     | Ok(kl') => | ||||
|       kl'->expect->toBeSoCloseTo(analyticalKlContinuousPart +. analyticalKlDiscretePart, ~digits=1) | ||||
|     | Error(err) => | ||||
|       Js.Console.log(DistributionTypes.Error.toString(err)) | ||||
|       raise(KlFailed) | ||||
|     } | ||||
|   }) | ||||
})
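
For orientation: with the analytic pieces above, the first finite test expects roughly $1.654 / 2 + \tfrac{1}{2}\ln 2 \approx 1.17$ nats, and the second roughly $\tfrac{1}{4}\ln 3 + \tfrac{3}{4}\ln\tfrac{4}{3} \approx 0.49$ nats; the loose `~digits=1` comparison leaves room for discretization error in the computed mixtures.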
| describe("combineAlongSupportOfSecondArgument0", () => { | ||||
|   // This tests the version of the function that we're NOT using. Haven't deleted the test in case we use the code later. | ||||
|   test("test on two uniforms", _ => { | ||||
---
@@ -302,10 +302,9 @@ module T = Dist({
  }

  let klDivergence = (prediction: t, answer: t) => {
-    Error(Operation.NotYetImplemented)
-    //    combinePointwise(PointSetDist_Scoring.KLDivergence.integrand, prediction, answer) |> E.R.fmap(
-    //     integralEndY,
-    //   )
+    let klDiscretePart = Discrete.T.klDivergence(prediction.discrete, answer.discrete)
+    let klContinuousPart = Continuous.T.klDivergence(prediction.continuous, answer.continuous)
+    E.R.merge(klDiscretePart, klContinuousPart)->E.R2.fmap(t => fst(t) +. snd(t))
  }
})
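
Why a plain sum is enough: an answer's point masses can only pick up likelihood from the prediction's point masses (a continuous density assigns zero probability to any single point), so the discrete and continuous sub-divergences are independent and the total KL is their sum. A minimal sketch of the merge step, using ReScript's built-in `result` in place of the repo's `E.R` helpers (`combineParts` is an illustrative name, not codebase API):

```rescript
// Combine per-part KL results: the total is the sum of the discrete
// and continuous contributions, and any error short-circuits.
let combineParts = (
  discretePart: result<float, string>,
  continuousPart: result<float, string>,
): result<float, string> =>
  switch (discretePart, continuousPart) {
  | (Ok(d), Ok(c)) => Ok(d +. c)
  | (Error(e), _) | (_, Error(e)) => Error(e)
  }

// combineParts(Ok(0.35), Ok(0.83)) => Ok(1.18)
```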

---
@@ -200,6 +200,7 @@ module T = Dist({
    switch (t1, t2) {
    | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2)
    | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2)
    | (Mixed(t1), Mixed(t2)) => Mixed.T.klDivergence(t1, t2)
    | _ => Error(NotYetImplemented)
    }
})
---
@@ -453,6 +453,44 @@ module PointwiseCombination = {
    T.filterOkYs(newXs, newYs)->Ok

  /* *Dead code*: Nuño wrote this function to try to increase precision, but it didn't work.
     If another traveler comes through with a similar idea, we hope this implementation will help them.
     By "enrich" we mean to increase granularity.
  */
  let enrichXyShape = (t: T.t): T.t => {
    let defaultEnrichmentFactor = 10
    let length = E.A.length(t.xs)
    let points =
      length < MagicNumbers.Environment.defaultXYPointLength
        ? defaultEnrichmentFactor * MagicNumbers.Environment.defaultXYPointLength / length
        : defaultEnrichmentFactor

    let getInBetween = (x1: float, x2: float): array<float> => {
      if abs_float(x1 -. x2) < 2.0 *. MagicNumbers.Epsilon.seven {
        [x1]
      } else {
        let newPointsArray = Belt.Array.makeBy(points - 1, i => i)
        // don't repeat the x2 point, it will be gotten in the next iteration.
        let result = Js.Array.mapi((pos, i) =>
          if i == 0 {
            x1
          } else {
            let points' = Belt.Float.fromInt(points)
            let pos' = Belt.Float.fromInt(pos)
            x1 *. (points' -. pos') /. points' +. x2 *. pos' /. points'
          }
        , newPointsArray)
        result
      }
    }
    let newXsUnflattened = Js.Array.mapi(
      (x, i) => i < length - 2 ? getInBetween(x, t.xs[i + 1]) : [x],
      t.xs,
    )
    let newXs = Belt.Array.concatMany(newXsUnflattened)
    let newYs = E.A.fmap(x => XtoY.linear(x, t), newXs)
    {xs: newXs, ys: newYs}
  }
  // This function is used for klDivergence
  let combineAlongSupportOfSecondArgument: (
    (float, float) => result<float, Operation.Error.t>,
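
For intuition about the dead code above: `getInBetween` walks an even grid between `x1` and `x2`, reserving `x2` itself for the next interval. A hypothetical standalone variant (it emits `points` values per interval, where the version above emits `points - 1`):

```rescript
// Evenly interpolate `points` x-values across [x1, x2), excluding x2 itself.
// pos = 0 yields exactly x1, matching the i == 0 branch above.
let getInBetween = (points: int, x1: float, x2: float): array<float> => {
  let points' = Belt.Float.fromInt(points)
  Belt.Array.makeBy(points, pos => {
    let pos' = Belt.Float.fromInt(pos)
    x1 *. (points' -. pos') /. points' +. x2 *. pos' /. points'
  })
}

Js.log(getInBetween(5, 0.0, 10.0)) // [0, 2, 4, 6, 8]
```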