Replaced some scoring functions with functionRegistry functions

This commit is contained in:
Ozzie Gooen 2022-07-13 08:56:22 -07:00
parent 652394f535
commit 3d71d7bf4e
4 changed files with 75 additions and 49 deletions

View File

@ -8,6 +8,7 @@ type rec frType =
| FRTypeNumber
| FRTypeNumeric
| FRTypeDistOrNumber
| FRTypeDist
| FRTypeLambda
| FRTypeRecord(frTypeRecord)
| FRTypeDict(frType)
@ -60,6 +61,7 @@ module FRType = {
switch t {
| FRTypeNumber => "number"
| FRTypeNumeric => "numeric"
| FRTypeDist => "distribution"
| FRTypeDistOrNumber => "distribution|number"
| FRTypeRecord(r) => {
let input = ((name, frType): frTypeRecordParam) => `${name}: ${toString(frType)}`
@ -98,6 +100,7 @@ module FRType = {
| (FRTypeDistOrNumber, IEvDistribution(Symbolic(#Float(f)))) =>
Some(FRValueDistOrNumber(FRValueNumber(f)))
| (FRTypeDistOrNumber, IEvDistribution(f)) => Some(FRValueDistOrNumber(FRValueDist(f)))
| (FRTypeDist, IEvDistribution(f)) => Some(FRValueDist(f))
| (FRTypeNumeric, IEvNumber(f)) => Some(FRValueNumber(f))
| (FRTypeNumeric, IEvDistribution(Symbolic(#Float(f)))) => Some(FRValueNumber(f))
| (FRTypeLambda, IEvLambda(f)) => Some(FRValueLambda(f))

View File

@ -27,6 +27,12 @@ module Prepare = {
| _ => Error(impossibleError)
}
let threeArgs = (inputs: ts): result<ts, err> =>
switch inputs {
| [FRValueRecord([(_, n1), (_, n2), (_, n3)])] => Ok([n1, n2, n3])
| _ => Error(impossibleError)
}
let toArgs = (inputs: ts): result<ts, err> =>
switch inputs {
| [FRValueRecord(args)] => args->E.A2.fmap(((_, b)) => b)->Ok
@ -57,6 +63,13 @@ module Prepare = {
}
}
let twoDist = (values: ts): result<(DistributionTypes.genericDist, DistributionTypes.genericDist), err> => {
switch values {
| [FRValueDist(a1), FRValueDist(a2)] => Ok(a1, a2)
| _ => Error(impossibleError)
}
}
let twoNumbers = (values: ts): result<(float, float), err> => {
switch values {
| [FRValueNumber(a1), FRValueNumber(a2)] => Ok(a1, a2)
@ -81,6 +94,9 @@ module Prepare = {
module Record = {
let twoDistOrNumber = (values: ts): result<(frValueDistOrNumber, frValueDistOrNumber), err> =>
values->ToValueArray.Record.twoArgs->E.R.bind(twoDistOrNumber)
let twoDist = (values: ts): result<(DistributionTypes.genericDist, DistributionTypes.genericDist), err> =>
values->ToValueArray.Record.twoArgs->E.R.bind(twoDist)
}
}

View File

@ -49,7 +49,7 @@ let inputsTodist = (inputs: array<FunctionRegistry_Core.frValue>, makeDist) => {
expressionValue
}
let registry = [
let registryStart = [
Function.make(
~name="toContinuousPointSet",
~definitions=[
@ -510,3 +510,58 @@ to(5,10)
(),
),
]
let runScoring = (estimate, answer, prior) => {
GenericDist.Score.logScore(~estimate, ~answer, ~prior)
->E.R2.fmap(FunctionRegistry_Helpers.Wrappers.evNumber)
->E.R2.errMap(DistributionTypes.Error.toString)
}
let scoreFunctions = [
Function.make(
~name="Score",
~definitions=[
FnDefinition.make(
~name="logScore",
~inputs=[
FRTypeRecord([
("estimate", FRTypeDist),
("answer", FRTypeDistOrNumber),
("prior", FRTypeDist),
]),
],
~run=(inputs, _) => {
switch FunctionRegistry_Helpers.Prepare.ToValueArray.Record.threeArgs(inputs) {
| Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueDist(d)), FRValueDist(prior)]) =>
runScoring(estimate, Score_Dist(d), Some(prior))
| Ok([
FRValueDist(estimate),
FRValueDistOrNumber(FRValueNumber(d)),
FRValueDist(prior),
]) =>
runScoring(estimate, Score_Scalar(d), Some(prior))
| Error(e) => Error(e)
| _ => Error(FunctionRegistry_Helpers.impossibleError)
}
},
),
FnDefinition.make(
~name="logScore",
~inputs=[FRTypeRecord([("estimate", FRTypeDist), ("answer", FRTypeDistOrNumber)])],
~run=(inputs, _) => {
switch FunctionRegistry_Helpers.Prepare.ToValueArray.Record.twoArgs(inputs) {
| Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueDist(d))]) =>
runScoring(estimate, Score_Dist(d), None)
| Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueNumber(d))]) =>
runScoring(estimate, Score_Scalar(d), None)
| Error(e) => Error(e)
| _ => Error(FunctionRegistry_Helpers.impossibleError)
}
},
),
],
(),
),
]
let registry = E.A.append(registryStart, scoreFunctions)

View File

@ -223,54 +223,6 @@ let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperatio
~env,
),
)
| (
"klDivergence",
[IEvDistribution(prediction), IEvDistribution(answer), IEvDistribution(prior)],
) =>
Some(
DistributionOperation.run(
FromDist(
#ToScore(
LogScore(DistributionTypes.DistributionOperation.Score_Dist(answer), Some(prior)),
),
prediction,
),
~env,
),
)
| (
"logScoreWithPointAnswer",
[IEvDistribution(prediction), IEvNumber(answer), IEvDistribution(prior)],
)
| (
"logScoreWithPointAnswer",
[
IEvDistribution(prediction),
IEvDistribution(Symbolic(#Float(answer))),
IEvDistribution(prior),
],
) =>
DistributionOperation.run(
FromDist(
#ToScore(
LogScore(DistributionTypes.DistributionOperation.Score_Scalar(answer), prior->Some),
),
prediction,
),
~env,
)->Some
| ("logScoreWithPointAnswer", [IEvDistribution(prediction), IEvNumber(answer)])
| (
"logScoreWithPointAnswer",
[IEvDistribution(prediction), IEvDistribution(Symbolic(#Float(answer)))],
) =>
DistributionOperation.run(
FromDist(
#ToScore(LogScore(DistributionTypes.DistributionOperation.Score_Scalar(answer), None)),
prediction,
),
~env,
)->Some
| ("isNormalized", [IEvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env)
| ("toPointSet", [IEvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env)
| ("scaleLog", [IEvDistribution(dist)]) =>