Replaced some scoring functions with functionRegistry functions
This commit is contained in:
parent
652394f535
commit
3d71d7bf4e
|
@ -8,6 +8,7 @@ type rec frType =
|
|||
| FRTypeNumber
|
||||
| FRTypeNumeric
|
||||
| FRTypeDistOrNumber
|
||||
| FRTypeDist
|
||||
| FRTypeLambda
|
||||
| FRTypeRecord(frTypeRecord)
|
||||
| FRTypeDict(frType)
|
||||
|
@ -60,6 +61,7 @@ module FRType = {
|
|||
switch t {
|
||||
| FRTypeNumber => "number"
|
||||
| FRTypeNumeric => "numeric"
|
||||
| FRTypeDist => "distribution"
|
||||
| FRTypeDistOrNumber => "distribution|number"
|
||||
| FRTypeRecord(r) => {
|
||||
let input = ((name, frType): frTypeRecordParam) => `${name}: ${toString(frType)}`
|
||||
|
@ -98,6 +100,7 @@ module FRType = {
|
|||
| (FRTypeDistOrNumber, IEvDistribution(Symbolic(#Float(f)))) =>
|
||||
Some(FRValueDistOrNumber(FRValueNumber(f)))
|
||||
| (FRTypeDistOrNumber, IEvDistribution(f)) => Some(FRValueDistOrNumber(FRValueDist(f)))
|
||||
| (FRTypeDist, IEvDistribution(f)) => Some(FRValueDist(f))
|
||||
| (FRTypeNumeric, IEvNumber(f)) => Some(FRValueNumber(f))
|
||||
| (FRTypeNumeric, IEvDistribution(Symbolic(#Float(f)))) => Some(FRValueNumber(f))
|
||||
| (FRTypeLambda, IEvLambda(f)) => Some(FRValueLambda(f))
|
||||
|
|
|
@ -27,6 +27,12 @@ module Prepare = {
|
|||
| _ => Error(impossibleError)
|
||||
}
|
||||
|
||||
let threeArgs = (inputs: ts): result<ts, err> =>
|
||||
switch inputs {
|
||||
| [FRValueRecord([(_, n1), (_, n2), (_, n3)])] => Ok([n1, n2, n3])
|
||||
| _ => Error(impossibleError)
|
||||
}
|
||||
|
||||
let toArgs = (inputs: ts): result<ts, err> =>
|
||||
switch inputs {
|
||||
| [FRValueRecord(args)] => args->E.A2.fmap(((_, b)) => b)->Ok
|
||||
|
@ -57,6 +63,13 @@ module Prepare = {
|
|||
}
|
||||
}
|
||||
|
||||
let twoDist = (values: ts): result<(DistributionTypes.genericDist, DistributionTypes.genericDist), err> => {
|
||||
switch values {
|
||||
| [FRValueDist(a1), FRValueDist(a2)] => Ok(a1, a2)
|
||||
| _ => Error(impossibleError)
|
||||
}
|
||||
}
|
||||
|
||||
let twoNumbers = (values: ts): result<(float, float), err> => {
|
||||
switch values {
|
||||
| [FRValueNumber(a1), FRValueNumber(a2)] => Ok(a1, a2)
|
||||
|
@ -81,6 +94,9 @@ module Prepare = {
|
|||
module Record = {
|
||||
let twoDistOrNumber = (values: ts): result<(frValueDistOrNumber, frValueDistOrNumber), err> =>
|
||||
values->ToValueArray.Record.twoArgs->E.R.bind(twoDistOrNumber)
|
||||
|
||||
let twoDist = (values: ts): result<(DistributionTypes.genericDist, DistributionTypes.genericDist), err> =>
|
||||
values->ToValueArray.Record.twoArgs->E.R.bind(twoDist)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ let inputsTodist = (inputs: array<FunctionRegistry_Core.frValue>, makeDist) => {
|
|||
expressionValue
|
||||
}
|
||||
|
||||
let registry = [
|
||||
let registryStart = [
|
||||
Function.make(
|
||||
~name="toContinuousPointSet",
|
||||
~definitions=[
|
||||
|
@ -510,3 +510,58 @@ to(5,10)
|
|||
(),
|
||||
),
|
||||
]
|
||||
|
||||
let runScoring = (estimate, answer, prior) => {
|
||||
GenericDist.Score.logScore(~estimate, ~answer, ~prior)
|
||||
->E.R2.fmap(FunctionRegistry_Helpers.Wrappers.evNumber)
|
||||
->E.R2.errMap(DistributionTypes.Error.toString)
|
||||
}
|
||||
|
||||
let scoreFunctions = [
|
||||
Function.make(
|
||||
~name="Score",
|
||||
~definitions=[
|
||||
FnDefinition.make(
|
||||
~name="logScore",
|
||||
~inputs=[
|
||||
FRTypeRecord([
|
||||
("estimate", FRTypeDist),
|
||||
("answer", FRTypeDistOrNumber),
|
||||
("prior", FRTypeDist),
|
||||
]),
|
||||
],
|
||||
~run=(inputs, _) => {
|
||||
switch FunctionRegistry_Helpers.Prepare.ToValueArray.Record.threeArgs(inputs) {
|
||||
| Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueDist(d)), FRValueDist(prior)]) =>
|
||||
runScoring(estimate, Score_Dist(d), Some(prior))
|
||||
| Ok([
|
||||
FRValueDist(estimate),
|
||||
FRValueDistOrNumber(FRValueNumber(d)),
|
||||
FRValueDist(prior),
|
||||
]) =>
|
||||
runScoring(estimate, Score_Scalar(d), Some(prior))
|
||||
| Error(e) => Error(e)
|
||||
| _ => Error(FunctionRegistry_Helpers.impossibleError)
|
||||
}
|
||||
},
|
||||
),
|
||||
FnDefinition.make(
|
||||
~name="logScore",
|
||||
~inputs=[FRTypeRecord([("estimate", FRTypeDist), ("answer", FRTypeDistOrNumber)])],
|
||||
~run=(inputs, _) => {
|
||||
switch FunctionRegistry_Helpers.Prepare.ToValueArray.Record.twoArgs(inputs) {
|
||||
| Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueDist(d))]) =>
|
||||
runScoring(estimate, Score_Dist(d), None)
|
||||
| Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueNumber(d))]) =>
|
||||
runScoring(estimate, Score_Scalar(d), None)
|
||||
| Error(e) => Error(e)
|
||||
| _ => Error(FunctionRegistry_Helpers.impossibleError)
|
||||
}
|
||||
},
|
||||
),
|
||||
],
|
||||
(),
|
||||
),
|
||||
]
|
||||
|
||||
let registry = E.A.append(registryStart, scoreFunctions)
|
||||
|
|
|
@ -223,54 +223,6 @@ let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperatio
|
|||
~env,
|
||||
),
|
||||
)
|
||||
| (
|
||||
"klDivergence",
|
||||
[IEvDistribution(prediction), IEvDistribution(answer), IEvDistribution(prior)],
|
||||
) =>
|
||||
Some(
|
||||
DistributionOperation.run(
|
||||
FromDist(
|
||||
#ToScore(
|
||||
LogScore(DistributionTypes.DistributionOperation.Score_Dist(answer), Some(prior)),
|
||||
),
|
||||
prediction,
|
||||
),
|
||||
~env,
|
||||
),
|
||||
)
|
||||
| (
|
||||
"logScoreWithPointAnswer",
|
||||
[IEvDistribution(prediction), IEvNumber(answer), IEvDistribution(prior)],
|
||||
)
|
||||
| (
|
||||
"logScoreWithPointAnswer",
|
||||
[
|
||||
IEvDistribution(prediction),
|
||||
IEvDistribution(Symbolic(#Float(answer))),
|
||||
IEvDistribution(prior),
|
||||
],
|
||||
) =>
|
||||
DistributionOperation.run(
|
||||
FromDist(
|
||||
#ToScore(
|
||||
LogScore(DistributionTypes.DistributionOperation.Score_Scalar(answer), prior->Some),
|
||||
),
|
||||
prediction,
|
||||
),
|
||||
~env,
|
||||
)->Some
|
||||
| ("logScoreWithPointAnswer", [IEvDistribution(prediction), IEvNumber(answer)])
|
||||
| (
|
||||
"logScoreWithPointAnswer",
|
||||
[IEvDistribution(prediction), IEvDistribution(Symbolic(#Float(answer)))],
|
||||
) =>
|
||||
DistributionOperation.run(
|
||||
FromDist(
|
||||
#ToScore(LogScore(DistributionTypes.DistributionOperation.Score_Scalar(answer), None)),
|
||||
prediction,
|
||||
),
|
||||
~env,
|
||||
)->Some
|
||||
| ("isNormalized", [IEvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env)
|
||||
| ("toPointSet", [IEvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env)
|
||||
| ("scaleLog", [IEvDistribution(dist)]) =>
|
||||
|
|
Loading…
Reference in New Issue
Block a user