good evening, not working yet, but out of time for the night
Value: [1e-6 to 1e-4]
This commit is contained in:
		
							parent
							
								
									b2d80eef86
								
							
						
					
					
						commit
						ccd55ef8f1
					
				|  | @ -19,7 +19,7 @@ describe("kl divergence on continuous distributions", () => { | |||
|       let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer)) | ||||
|       let kl = E.R.liftJoin2(klDivergence, prediction, answer) | ||||
|       switch kl { | ||||
|       | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) | ||||
|       | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7) | ||||
|       | Error(err) => { | ||||
|           Js.Console.log(DistributionTypes.Error.toString(err)) | ||||
|           raise(KlFailed) | ||||
|  | @ -51,7 +51,7 @@ describe("kl divergence on continuous distributions", () => { | |||
|     let kl = E.R.liftJoin2(klDivergence, prediction, answer) | ||||
| 
 | ||||
|     switch kl { | ||||
|     | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) | ||||
|     | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3) | ||||
|     | Error(err) => { | ||||
|         Js.Console.log(DistributionTypes.Error.toString(err)) | ||||
|         raise(KlFailed) | ||||
|  | @ -78,9 +78,9 @@ describe("kl divergence on discrete distributions", () => { | |||
|     | (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer') | ||||
|     | _ => raise(MixtureFailed) | ||||
|     } | ||||
|     let analyticalKl = Js.Math.log(2.0 /. 3.0) | ||||
|     let analyticalKl = Js.Math.log(3.0 /. 2.0) | ||||
|     switch kl { | ||||
|     | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) | ||||
|     | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7) | ||||
|     | Error(err) => | ||||
|       Js.Console.log(DistributionTypes.Error.toString(err)) | ||||
|       raise(KlFailed) | ||||
|  |  | |||
|  | @ -51,6 +51,7 @@ let mkExponential = rate => DistributionTypes.Symbolic(#Exponential({rate: rate} | |||
| let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high})) | ||||
| let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale})) | ||||
| let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma})) | ||||
| let mkDirac = x => DistributionTypes.Symbolic(#Float(x)) | ||||
| 
 | ||||
| let normalMake = SymbolicDist.Normal.make | ||||
| let betaMake = SymbolicDist.Beta.make | ||||
|  |  | |||
|  | @ -229,11 +229,25 @@ module T = Dist({ | |||
|   } | ||||
| 
 | ||||
|   let klDivergence = (prediction: t, answer: t) => { | ||||
|     combinePointwise( | ||||
|       ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0, | ||||
|       ~fn=PointSetDist_Scoring.KLDivergence.integrand, | ||||
|       prediction, | ||||
|       answer, | ||||
|     ) |> E.R2.bind(integralEndYResult) | ||||
|     let massOrZero = (t: t, x: float): float => { | ||||
|       let i = E.A.findIndex(x' => x' == x, t.xyShape.xs) | ||||
|       switch i { | ||||
|       | None => 0.0 | ||||
|       | Some(i') => t.xyShape.ys[i'] | ||||
|       } | ||||
|     } | ||||
|     let predictionNewYs = E.A.fmap(massOrZero(answer), prediction.xyShape.xs) | ||||
|     let integrand = XYShape.PointwiseCombination.combine( | ||||
|       PointSetDist_Scoring.KLDivergence.integrand, | ||||
|       XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero), | ||||
|       {XYShape.xs: answer.xyShape.xs, XYShape.ys: predictionNewYs}, | ||||
|       answer.xyShape, | ||||
|     ) | ||||
|     let xyShapeToDiscrete: XYShape.xyShape => t = xyShape => { | ||||
|       xyShape: xyShape, | ||||
|       integralSumCache: None, | ||||
|       integralCache: None, | ||||
|     } | ||||
|     integrand->E.R2.fmap(x => x->xyShapeToDiscrete->integralEndY) | ||||
|   } | ||||
| }) | ||||
|  |  | |||
|  | @ -199,6 +199,7 @@ module T = Dist({ | |||
|   let klDivergence = (t1: t, t2: t) => | ||||
|     switch (t1, t2) { | ||||
|     | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | ||||
|     | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) | ||||
|     | _ => Error(NotYetImplemented) | ||||
|     } | ||||
| }) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user