good evening, not working yet, but out of time for the night
Value: [1e-6 to 1e-4]
This commit is contained in:
		
							parent
							
								
									b2d80eef86
								
							
						
					
					
						commit
						ccd55ef8f1
					
				|  | @ -19,7 +19,7 @@ describe("kl divergence on continuous distributions", () => { | ||||||
|       let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer)) |       let analyticalKl = Js.Math.log((highPrediction -. lowPrediction) /. (highAnswer -. lowAnswer)) | ||||||
|       let kl = E.R.liftJoin2(klDivergence, prediction, answer) |       let kl = E.R.liftJoin2(klDivergence, prediction, answer) | ||||||
|       switch kl { |       switch kl { | ||||||
|       | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) |       | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7) | ||||||
|       | Error(err) => { |       | Error(err) => { | ||||||
|           Js.Console.log(DistributionTypes.Error.toString(err)) |           Js.Console.log(DistributionTypes.Error.toString(err)) | ||||||
|           raise(KlFailed) |           raise(KlFailed) | ||||||
|  | @ -51,7 +51,7 @@ describe("kl divergence on continuous distributions", () => { | ||||||
|     let kl = E.R.liftJoin2(klDivergence, prediction, answer) |     let kl = E.R.liftJoin2(klDivergence, prediction, answer) | ||||||
| 
 | 
 | ||||||
|     switch kl { |     switch kl { | ||||||
|     | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) |     | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=3) | ||||||
|     | Error(err) => { |     | Error(err) => { | ||||||
|         Js.Console.log(DistributionTypes.Error.toString(err)) |         Js.Console.log(DistributionTypes.Error.toString(err)) | ||||||
|         raise(KlFailed) |         raise(KlFailed) | ||||||
|  | @ -78,9 +78,9 @@ describe("kl divergence on discrete distributions", () => { | ||||||
|     | (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer') |     | (Dist(prediction'), Dist(answer')) => klDivergence(prediction', answer') | ||||||
|     | _ => raise(MixtureFailed) |     | _ => raise(MixtureFailed) | ||||||
|     } |     } | ||||||
|     let analyticalKl = Js.Math.log(2.0 /. 3.0) |     let analyticalKl = Js.Math.log(3.0 /. 2.0) | ||||||
|     switch kl { |     switch kl { | ||||||
|     | Ok(kl') => kl'->expect->toBeCloseTo(analyticalKl) |     | Ok(kl') => kl'->expect->toBeSoCloseTo(analyticalKl, ~digits=7) | ||||||
|     | Error(err) => |     | Error(err) => | ||||||
|       Js.Console.log(DistributionTypes.Error.toString(err)) |       Js.Console.log(DistributionTypes.Error.toString(err)) | ||||||
|       raise(KlFailed) |       raise(KlFailed) | ||||||
|  |  | ||||||
|  | @ -51,6 +51,7 @@ let mkExponential = rate => DistributionTypes.Symbolic(#Exponential({rate: rate} | ||||||
| let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high})) | let mkUniform = (low, high) => DistributionTypes.Symbolic(#Uniform({low: low, high: high})) | ||||||
| let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale})) | let mkCauchy = (local, scale) => DistributionTypes.Symbolic(#Cauchy({local: local, scale: scale})) | ||||||
| let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma})) | let mkLognormal = (mu, sigma) => DistributionTypes.Symbolic(#Lognormal({mu: mu, sigma: sigma})) | ||||||
|  | let mkDirac = x => DistributionTypes.Symbolic(#Float(x)) | ||||||
| 
 | 
 | ||||||
| let normalMake = SymbolicDist.Normal.make | let normalMake = SymbolicDist.Normal.make | ||||||
| let betaMake = SymbolicDist.Beta.make | let betaMake = SymbolicDist.Beta.make | ||||||
|  |  | ||||||
|  | @ -229,11 +229,25 @@ module T = Dist({ | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   let klDivergence = (prediction: t, answer: t) => { |   let klDivergence = (prediction: t, answer: t) => { | ||||||
|     combinePointwise( |     let massOrZero = (t: t, x: float): float => { | ||||||
|       ~combiner=XYShape.PointwiseCombination.combineAlongSupportOfSecondArgument0, |       let i = E.A.findIndex(x' => x' == x, t.xyShape.xs) | ||||||
|       ~fn=PointSetDist_Scoring.KLDivergence.integrand, |       switch i { | ||||||
|       prediction, |       | None => 0.0 | ||||||
|       answer, |       | Some(i') => t.xyShape.ys[i'] | ||||||
|     ) |> E.R2.bind(integralEndYResult) |       } | ||||||
|  |     } | ||||||
|  |     let predictionNewYs = E.A.fmap(massOrZero(answer), prediction.xyShape.xs) | ||||||
|  |     let integrand = XYShape.PointwiseCombination.combine( | ||||||
|  |       PointSetDist_Scoring.KLDivergence.integrand, | ||||||
|  |       XYShape.XtoY.continuousInterpolator(#Stepwise, #UseZero), | ||||||
|  |       {XYShape.xs: answer.xyShape.xs, XYShape.ys: predictionNewYs}, | ||||||
|  |       answer.xyShape, | ||||||
|  |     ) | ||||||
|  |     let xyShapeToDiscrete: XYShape.xyShape => t = xyShape => { | ||||||
|  |       xyShape: xyShape, | ||||||
|  |       integralSumCache: None, | ||||||
|  |       integralCache: None, | ||||||
|  |     } | ||||||
|  |     integrand->E.R2.fmap(x => x->xyShapeToDiscrete->integralEndY) | ||||||
|   } |   } | ||||||
| }) | }) | ||||||
|  |  | ||||||
|  | @ -199,6 +199,7 @@ module T = Dist({ | ||||||
|   let klDivergence = (t1: t, t2: t) => |   let klDivergence = (t1: t, t2: t) => | ||||||
|     switch (t1, t2) { |     switch (t1, t2) { | ||||||
|     | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) |     | (Continuous(t1), Continuous(t2)) => Continuous.T.klDivergence(t1, t2) | ||||||
|  |     | (Discrete(t1), Discrete(t2)) => Discrete.T.klDivergence(t1, t2) | ||||||
|     | _ => Error(NotYetImplemented) |     | _ => Error(NotYetImplemented) | ||||||
|     } |     } | ||||||
| }) | }) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user