Cleaned up Scoring file: no dispatch yet
				
					
				
			Value: [1e-4 to 6e-2]
This commit is contained in:
		
							parent
							
								
									7b865c95f5
								
							
						
					
					
						commit
						bafcb4f7b8
					
				| 
						 | 
				
			
			@ -1,46 +1,118 @@
 | 
			
		|||
module KLDivergence = {
 | 
			
		||||
  let logFn = Js.Math.log // base e
 | 
			
		||||
  let integrand = (predictionElement: float, answerElement: float): result<
 | 
			
		||||
type t = PointSetDist.pointSetDist
 | 
			
		||||
type continuousShape = PointSetTypes.continuousShape
 | 
			
		||||
type discreteShape = PointSetTypes.discreteShape
 | 
			
		||||
type mixedShape = PointSetTypes.mixedShape
 | 
			
		||||
type scalar = float
 | 
			
		||||
type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>}
 | 
			
		||||
type scoreArgs =
 | 
			
		||||
  | DistEstimateDistAnswer(abstractScoreArgs<t, t>)
 | 
			
		||||
  | DistEstimateScalarAnswer(abstractScoreArgs<t, scalar>)
 | 
			
		||||
  | ScalarEstimateDistAnswer(abstractScoreArgs<scalar, t>)
 | 
			
		||||
  | ScalarEstimateScalarAnswer(abstractScoreArgs<scalar, scalar>)
 | 
			
		||||
let logFn = Js.Math.log // base e
 | 
			
		||||
let minusScaledLogOfQuot = (~esti, ~answ): result<float, Operation.Error.t> => {
 | 
			
		||||
  let quot = esti /. answ
 | 
			
		||||
  quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answ *. logFn(quot))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
module WithDistAnswer = {
 | 
			
		||||
  // The Kullback-Leibler divergence
 | 
			
		||||
  let integrand = (estimateElement: float, answerElement: float): result<
 | 
			
		||||
    float,
 | 
			
		||||
    Operation.Error.t,
 | 
			
		||||
  > =>
 | 
			
		||||
    // We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
 | 
			
		||||
    if answerElement == 0.0 {
 | 
			
		||||
      Ok(0.0)
 | 
			
		||||
    } else if predictionElement == 0.0 {
 | 
			
		||||
    } else if estimateElement == 0.0 {
 | 
			
		||||
      Ok(infinity)
 | 
			
		||||
    } else {
 | 
			
		||||
      let quot = predictionElement /. answerElement
 | 
			
		||||
      quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
 | 
			
		||||
      minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  let sum = (~estimate: t, ~answer: t, ~integrateFn) =>
 | 
			
		||||
    PointSetDist.combinePointwise(integrand, estimate, answer)->E.R2.fmap(integrateFn)
 | 
			
		||||
 | 
			
		||||
  let sumWithPrior = (~estimate: t, ~answer: t, ~prior: t, ~integrateFn): result<
 | 
			
		||||
    float,
 | 
			
		||||
    Operation.Error.t,
 | 
			
		||||
  > => {
 | 
			
		||||
    let kl1 = sum(~estimate, ~answer, ~integrateFn)
 | 
			
		||||
    let kl2 = sum(~estimate=prior, ~answer, ~integrateFn)
 | 
			
		||||
    E.R.merge(kl1, kl2)->E.R2.fmap(((k1', k2')) => kl1' -. kl2')
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
module LogScoreWithPointResolution = {
 | 
			
		||||
  let logFn = Js.Math.log
 | 
			
		||||
  let score = (
 | 
			
		||||
    ~priorPdf: option<float => float>,
 | 
			
		||||
    ~predictionPdf: float => float,
 | 
			
		||||
    ~answer: float,
 | 
			
		||||
  ): result<float, Operation.Error.t> => {
 | 
			
		||||
    let numerator = answer->predictionPdf
 | 
			
		||||
    if numerator < 0.0 {
 | 
			
		||||
module WithScalarAnswer = {
 | 
			
		||||
  let score' = (~estimatePdf: float => float, ~answer: float): result<float, Operation.Error.t> => {
 | 
			
		||||
    let density = answer->estimatePdf
 | 
			
		||||
    if density < 0.0 {
 | 
			
		||||
      Operation.PdfInvalidError->Error
 | 
			
		||||
    } else if numerator == 0.0 {
 | 
			
		||||
    } else if density == 0.0 {
 | 
			
		||||
      infinity->Ok
 | 
			
		||||
    } else {
 | 
			
		||||
      -.(
 | 
			
		||||
        switch priorPdf {
 | 
			
		||||
        | None => numerator->logFn
 | 
			
		||||
        | Some(f) => {
 | 
			
		||||
            let priorDensityOfAnswer = f(answer)
 | 
			
		||||
            if priorDensityOfAnswer == 0.0 {
 | 
			
		||||
              neg_infinity
 | 
			
		||||
      density->logFn->(x => -.x)->Ok
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  let scoreWithPrior' = (
 | 
			
		||||
    ~estimatePdf: float => float,
 | 
			
		||||
    ~answer: float,
 | 
			
		||||
    ~priorPdf: float => float,
 | 
			
		||||
  ): result<float, Operation.Error.t> => {
 | 
			
		||||
    let numerator = answer->estimatePdf
 | 
			
		||||
    let priorDensityOfAnswer = answer->priorPdf
 | 
			
		||||
    if numerator < 0.0 || priorDensityOfAnswer < 0.0 {
 | 
			
		||||
      Operation.PdfInvalidError->Error
 | 
			
		||||
    } else if numerator == 0.0 || priorDensityOfAnswer == 0.0 {
 | 
			
		||||
      infinity->Ok
 | 
			
		||||
    } else {
 | 
			
		||||
              (numerator /. priorDensityOfAnswer)->logFn
 | 
			
		||||
      minusScaledLogOfQuot(~esti=numerator, ~answ=priorDensityOfAnswer)
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  let score = (~estimate: t, ~answer: t): result<float, Operation.Error.t> => {
 | 
			
		||||
    let estimatePdf = x => XYShape.XtoY.linear(x, estimate.xyShape)
 | 
			
		||||
    score'(~estimatePdf, ~answer)
 | 
			
		||||
  }
 | 
			
		||||
      )->Ok
 | 
			
		||||
    }
 | 
			
		||||
  let scoreWithPrior = (~estimate: t, ~answer: t, ~prior: t): result<float, Operation.Error.t> => {
 | 
			
		||||
    let estimatePdf = x => XYShape.XtoY.linear(x, estimate.xyShape)
 | 
			
		||||
    let priorPdf = x => XYShape.XtoY.linear(x, prior.xyShape)
 | 
			
		||||
    scoreWithPrior'(~estimatePdf, ~answer, ~priorPdf)
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
module TwoScalars = {
 | 
			
		||||
  let score = (~estimate: float, ~answer: float) =>
 | 
			
		||||
    if answer == 0.0 {
 | 
			
		||||
      0.0->Ok
 | 
			
		||||
    } else if estimate == 0.0 {
 | 
			
		||||
      infinity->Ok
 | 
			
		||||
    } else {
 | 
			
		||||
      minusScaledLogOfQuot(~esti=estimate, ~answ=answer)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  let scoreWithPrior = (~estimate: float, ~answer: float, ~prior: float) =>
 | 
			
		||||
    if answer == 0.0 {
 | 
			
		||||
      0.0->Ok
 | 
			
		||||
    } else if estimate == 0.0 || prior == 0.0 {
 | 
			
		||||
      infinity->Ok
 | 
			
		||||
    } else {
 | 
			
		||||
      minusScaledLogOfQuot(~esti=estimate /. prior, ~answ=answer)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
let logScore = (args: scoreArgs, ~integrateFn): result<float, Operation.Error.t> =>
 | 
			
		||||
  switch args {
 | 
			
		||||
  | DistEstimateDistAnswer({estimate, answer, prior: None}) =>
 | 
			
		||||
    WithDistAnswer.sum(~estimate, ~answer, ~integrateFn)
 | 
			
		||||
  | DistEstimateDistAnswer({estimate, answer, prior: Some(prior)}) =>
 | 
			
		||||
    WithDistAnswer.sumWithPrior(~estimate, ~answer, ~prior, ~integrateFn)
 | 
			
		||||
  | DistEstimateScalarAnswer({estimate, answer, prior: None}) =>
 | 
			
		||||
    WithScalarAnswer.score(~estimate, ~answer)
 | 
			
		||||
  | DistEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
 | 
			
		||||
    WithScalarAnswer.scoreWithPrior(~estimate, ~answer, ~prior)
 | 
			
		||||
  | ScalarEstimateDistAnswer(_) => Operation.NotYetImplemented->Error
 | 
			
		||||
  | ScalarEstimateScalarAnswer({estimate, answer, prior: None}) =>
 | 
			
		||||
    TwoScalars.score(~estimate, ~answer)
 | 
			
		||||
  | ScalarEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
 | 
			
		||||
    TwoScalars.scoreWithPrior(~estimate, ~answer, ~prior)
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user