Cleaned up Scoring file: no dispatch yet
				
					
				
			Value: [1e-4 to 6e-2]
This commit is contained in:
		
							parent
							
								
									7b865c95f5
								
							
						
					
					
						commit
						bafcb4f7b8
					
				| 
						 | 
					@ -1,46 +1,118 @@
 | 
				
			||||||
module KLDivergence = {
 | 
					type t = PointSetDist.pointSetDist
 | 
				
			||||||
  let logFn = Js.Math.log // base e
 | 
					type continuousShape = PointSetTypes.continuousShape
 | 
				
			||||||
  let integrand = (predictionElement: float, answerElement: float): result<
 | 
					type discreteShape = PointSetTypes.discreteShape
 | 
				
			||||||
 | 
					type mixedShape = PointSetTypes.mixedShape
 | 
				
			||||||
 | 
					type scalar = float
 | 
				
			||||||
 | 
					type abstractScoreArgs<'a, 'b> = {estimate: 'a, answer: 'b, prior: option<'a>}
 | 
				
			||||||
 | 
					type scoreArgs =
 | 
				
			||||||
 | 
					  | DistEstimateDistAnswer(abstractScoreArgs<t, t>)
 | 
				
			||||||
 | 
					  | DistEstimateScalarAnswer(abstractScoreArgs<t, scalar>)
 | 
				
			||||||
 | 
					  | ScalarEstimateDistAnswer(abstractScoreArgs<scalar, t>)
 | 
				
			||||||
 | 
					  | ScalarEstimateScalarAnswer(abstractScoreArgs<scalar, scalar>)
 | 
				
			||||||
 | 
					let logFn = Js.Math.log // base e
 | 
				
			||||||
 | 
					let minusScaledLogOfQuot = (~esti, ~answ): result<float, Operation.Error.t> => {
 | 
				
			||||||
 | 
					  let quot = esti /. answ
 | 
				
			||||||
 | 
					  quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answ *. logFn(quot))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					module WithDistAnswer = {
 | 
				
			||||||
 | 
					  // The Kullback-Leibler divergence
 | 
				
			||||||
 | 
					  let integrand = (estimateElement: float, answerElement: float): result<
 | 
				
			||||||
    float,
 | 
					    float,
 | 
				
			||||||
    Operation.Error.t,
 | 
					    Operation.Error.t,
 | 
				
			||||||
  > =>
 | 
					  > =>
 | 
				
			||||||
    // We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
 | 
					    // We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value.
 | 
				
			||||||
    if answerElement == 0.0 {
 | 
					    if answerElement == 0.0 {
 | 
				
			||||||
      Ok(0.0)
 | 
					      Ok(0.0)
 | 
				
			||||||
    } else if predictionElement == 0.0 {
 | 
					    } else if estimateElement == 0.0 {
 | 
				
			||||||
      Ok(infinity)
 | 
					      Ok(infinity)
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      let quot = predictionElement /. answerElement
 | 
					      minusScaledLogOfQuot(~esti=estimateElement, ~answ=answerElement)
 | 
				
			||||||
      quot < 0.0 ? Error(Operation.ComplexNumberError) : Ok(-.answerElement *. logFn(quot))
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let sum = (~estimate: t, ~answer: t, ~integrateFn) =>
 | 
				
			||||||
 | 
					    PointSetDist.combinePointwise(integrand, estimate, answer)->E.R2.fmap(integrateFn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let sumWithPrior = (~estimate: t, ~answer: t, ~prior: t, ~integrateFn): result<
 | 
				
			||||||
 | 
					    float,
 | 
				
			||||||
 | 
					    Operation.Error.t,
 | 
				
			||||||
 | 
					  > => {
 | 
				
			||||||
 | 
					    let kl1 = sum(~estimate, ~answer, ~integrateFn)
 | 
				
			||||||
 | 
					    let kl2 = sum(~estimate=prior, ~answer, ~integrateFn)
 | 
				
			||||||
 | 
					    E.R.merge(kl1, kl2)->E.R2.fmap(((k1', k2')) => kl1' -. kl2')
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					module WithScalarAnswer = {
 | 
				
			||||||
 | 
					  let score' = (~estimatePdf: float => float, ~answer: float): result<float, Operation.Error.t> => {
 | 
				
			||||||
 | 
					    let density = answer->estimatePdf
 | 
				
			||||||
 | 
					    if density < 0.0 {
 | 
				
			||||||
 | 
					      Operation.PdfInvalidError->Error
 | 
				
			||||||
 | 
					    } else if density == 0.0 {
 | 
				
			||||||
 | 
					      infinity->Ok
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      density->logFn->(x => -.x)->Ok
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  let scoreWithPrior' = (
 | 
				
			||||||
 | 
					    ~estimatePdf: float => float,
 | 
				
			||||||
 | 
					    ~answer: float,
 | 
				
			||||||
 | 
					    ~priorPdf: float => float,
 | 
				
			||||||
 | 
					  ): result<float, Operation.Error.t> => {
 | 
				
			||||||
 | 
					    let numerator = answer->estimatePdf
 | 
				
			||||||
 | 
					    let priorDensityOfAnswer = answer->priorPdf
 | 
				
			||||||
 | 
					    if numerator < 0.0 || priorDensityOfAnswer < 0.0 {
 | 
				
			||||||
 | 
					      Operation.PdfInvalidError->Error
 | 
				
			||||||
 | 
					    } else if numerator == 0.0 || priorDensityOfAnswer == 0.0 {
 | 
				
			||||||
 | 
					      infinity->Ok
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      minusScaledLogOfQuot(~esti=numerator, ~answ=priorDensityOfAnswer)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  let score = (~estimate: t, ~answer: t): result<float, Operation.Error.t> => {
 | 
				
			||||||
 | 
					    let estimatePdf = x => XYShape.XtoY.linear(x, estimate.xyShape)
 | 
				
			||||||
 | 
					    score'(~estimatePdf, ~answer)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  let scoreWithPrior = (~estimate: t, ~answer: t, ~prior: t): result<float, Operation.Error.t> => {
 | 
				
			||||||
 | 
					    let estimatePdf = x => XYShape.XtoY.linear(x, estimate.xyShape)
 | 
				
			||||||
 | 
					    let priorPdf = x => XYShape.XtoY.linear(x, prior.xyShape)
 | 
				
			||||||
 | 
					    scoreWithPrior'(~estimatePdf, ~answer, ~priorPdf)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					module TwoScalars = {
 | 
				
			||||||
 | 
					  let score = (~estimate: float, ~answer: float) =>
 | 
				
			||||||
 | 
					    if answer == 0.0 {
 | 
				
			||||||
 | 
					      0.0->Ok
 | 
				
			||||||
 | 
					    } else if estimate == 0.0 {
 | 
				
			||||||
 | 
					      infinity->Ok
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      minusScaledLogOfQuot(~esti=estimate, ~answ=answer)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let scoreWithPrior = (~estimate: float, ~answer: float, ~prior: float) =>
 | 
				
			||||||
 | 
					    if answer == 0.0 {
 | 
				
			||||||
 | 
					      0.0->Ok
 | 
				
			||||||
 | 
					    } else if estimate == 0.0 || prior == 0.0 {
 | 
				
			||||||
 | 
					      infinity->Ok
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      minusScaledLogOfQuot(~esti=estimate /. prior, ~answ=answer)
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
module LogScoreWithPointResolution = {
 | 
					let logScore = (args: scoreArgs, ~integrateFn): result<float, Operation.Error.t> =>
 | 
				
			||||||
  let logFn = Js.Math.log
 | 
					  switch args {
 | 
				
			||||||
  let score = (
 | 
					  | DistEstimateDistAnswer({estimate, answer, prior: None}) =>
 | 
				
			||||||
    ~priorPdf: option<float => float>,
 | 
					    WithDistAnswer.sum(~estimate, ~answer, ~integrateFn)
 | 
				
			||||||
    ~predictionPdf: float => float,
 | 
					  | DistEstimateDistAnswer({estimate, answer, prior: Some(prior)}) =>
 | 
				
			||||||
    ~answer: float,
 | 
					    WithDistAnswer.sumWithPrior(~estimate, ~answer, ~prior, ~integrateFn)
 | 
				
			||||||
  ): result<float, Operation.Error.t> => {
 | 
					  | DistEstimateScalarAnswer({estimate, answer, prior: None}) =>
 | 
				
			||||||
    let numerator = answer->predictionPdf
 | 
					    WithScalarAnswer.score(~estimate, ~answer)
 | 
				
			||||||
    if numerator < 0.0 {
 | 
					  | DistEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
 | 
				
			||||||
      Operation.PdfInvalidError->Error
 | 
					    WithScalarAnswer.scoreWithPrior(~estimate, ~answer, ~prior)
 | 
				
			||||||
    } else if numerator == 0.0 {
 | 
					  | ScalarEstimateDistAnswer(_) => Operation.NotYetImplemented->Error
 | 
				
			||||||
      infinity->Ok
 | 
					  | ScalarEstimateScalarAnswer({estimate, answer, prior: None}) =>
 | 
				
			||||||
    } else {
 | 
					    TwoScalars.score(~estimate, ~answer)
 | 
				
			||||||
      -.(
 | 
					  | ScalarEstimateScalarAnswer({estimate, answer, prior: Some(prior)}) =>
 | 
				
			||||||
        switch priorPdf {
 | 
					    TwoScalars.scoreWithPrior(~estimate, ~answer, ~prior)
 | 
				
			||||||
        | None => numerator->logFn
 | 
					 | 
				
			||||||
        | Some(f) => {
 | 
					 | 
				
			||||||
            let priorDensityOfAnswer = f(answer)
 | 
					 | 
				
			||||||
            if priorDensityOfAnswer == 0.0 {
 | 
					 | 
				
			||||||
              neg_infinity
 | 
					 | 
				
			||||||
            } else {
 | 
					 | 
				
			||||||
              (numerator /. priorDensityOfAnswer)->logFn
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      )->Ok
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user