Merge pull request #829 from quantified-uncertainty/scoring-cleanup-three-fixes

Scoring cleanup three fixes
This commit is contained in:
Ozzie Gooen 2022-07-13 10:00:06 -07:00 committed by GitHub
commit 4853df39bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 244 additions and 172 deletions

View File

@ -1,7 +1,7 @@
open Jest open Jest
open Expect open Expect
let env: DistributionOperation.env = { let env: GenericDist.env = {
sampleCount: 100, sampleCount: 100,
xyPointLength: 100, xyPointLength: 100,
} }

View File

@ -29,7 +29,7 @@ let {toFloat, toDist, toString, toError, fmap} = module(DistributionOperation.Ou
let fnImage = (theFn, inps) => Js.Array.map(theFn, inps) let fnImage = (theFn, inps) => Js.Array.map(theFn, inps)
let env: DistributionOperation.env = { let env: GenericDist.env = {
sampleCount: MagicNumbers.Environment.defaultSampleCount, sampleCount: MagicNumbers.Environment.defaultSampleCount,
xyPointLength: MagicNumbers.Environment.defaultXYPointLength, xyPointLength: MagicNumbers.Environment.defaultXYPointLength,
} }

View File

@ -4,12 +4,9 @@ type error = DistributionTypes.error
// TODO: It could be great to use a cache for some calculations (basically, do memoization). Also, better analytics/tracking could go a long way. // TODO: It could be great to use a cache for some calculations (basically, do memoization). Also, better analytics/tracking could go a long way.
type env = { type env = GenericDist.env
sampleCount: int,
xyPointLength: int,
}
let defaultEnv = { let defaultEnv:env = {
sampleCount: MagicNumbers.Environment.defaultSampleCount, sampleCount: MagicNumbers.Environment.defaultSampleCount,
xyPointLength: MagicNumbers.Environment.defaultXYPointLength, xyPointLength: MagicNumbers.Environment.defaultXYPointLength,
} }
@ -93,7 +90,7 @@ module OutputLocal = {
} }
} }
let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { let rec run = (~env:env, functionCallInfo: functionCallInfo): outputType => {
let {sampleCount, xyPointLength} = env let {sampleCount, xyPointLength} = env
let reCall = (~env=env, ~functionCallInfo=functionCallInfo, ()) => { let reCall = (~env=env, ~functionCallInfo=functionCallInfo, ()) => {
@ -146,7 +143,7 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
} }
| #ToDist(Normalize) => dist->GenericDist.normalize->Dist | #ToDist(Normalize) => dist->GenericDist.normalize->Dist
| #ToScore(LogScore(answer, prior)) => | #ToScore(LogScore(answer, prior)) =>
GenericDist.Score.logScore(~estimate=Score_Dist(dist), ~answer, ~prior) GenericDist.Score.logScore(~estimate=dist, ~answer, ~prior, ~env)
->E.R2.fmap(s => Float(s)) ->E.R2.fmap(s => Float(s))
->OutputLocal.fromResult ->OutputLocal.fromResult
| #ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool | #ToBool(IsNormalized) => dist->GenericDist.isNormalized->Bool

View File

@ -1,11 +1,5 @@
@genType @genType
type env = { let defaultEnv: GenericDist.env
sampleCount: int,
xyPointLength: int,
}
@genType
let defaultEnv: env
open DistributionTypes open DistributionTypes
@ -19,14 +13,17 @@ type outputType =
| GenDistError(error) | GenDistError(error)
@genType @genType
let run: (~env: env, DistributionTypes.DistributionOperation.genericFunctionCallInfo) => outputType let run: (
~env: GenericDist.env,
DistributionTypes.DistributionOperation.genericFunctionCallInfo,
) => outputType
let runFromDist: ( let runFromDist: (
~env: env, ~env: GenericDist.env,
~functionCallInfo: DistributionTypes.DistributionOperation.fromDist, ~functionCallInfo: DistributionTypes.DistributionOperation.fromDist,
genericDist, genericDist,
) => outputType ) => outputType
let runFromFloat: ( let runFromFloat: (
~env: env, ~env: GenericDist.env,
~functionCallInfo: DistributionTypes.DistributionOperation.fromFloat, ~functionCallInfo: DistributionTypes.DistributionOperation.fromFloat,
float, float,
) => outputType ) => outputType
@ -42,90 +39,147 @@ module Output: {
let toBool: t => option<bool> let toBool: t => option<bool>
let toBoolR: t => result<bool, error> let toBoolR: t => result<bool, error>
let toError: t => option<error> let toError: t => option<error>
let fmap: (~env: env, t, DistributionTypes.DistributionOperation.singleParamaterFunction) => t let fmap: (
~env: GenericDist.env,
t,
DistributionTypes.DistributionOperation.singleParamaterFunction,
) => t
} }
module Constructors: { module Constructors: {
@genType @genType
let mean: (~env: env, genericDist) => result<float, error> let mean: (~env: GenericDist.env, genericDist) => result<float, error>
@genType @genType
let stdev: (~env: env, genericDist) => result<float, error> let stdev: (~env: GenericDist.env, genericDist) => result<float, error>
@genType @genType
let variance: (~env: env, genericDist) => result<float, error> let variance: (~env: GenericDist.env, genericDist) => result<float, error>
@genType @genType
let sample: (~env: env, genericDist) => result<float, error> let sample: (~env: GenericDist.env, genericDist) => result<float, error>
@genType @genType
let cdf: (~env: env, genericDist, float) => result<float, error> let cdf: (~env: GenericDist.env, genericDist, float) => result<float, error>
@genType @genType
let inv: (~env: env, genericDist, float) => result<float, error> let inv: (~env: GenericDist.env, genericDist, float) => result<float, error>
@genType @genType
let pdf: (~env: env, genericDist, float) => result<float, error> let pdf: (~env: GenericDist.env, genericDist, float) => result<float, error>
@genType @genType
let normalize: (~env: env, genericDist) => result<genericDist, error> let normalize: (~env: GenericDist.env, genericDist) => result<genericDist, error>
@genType @genType
let isNormalized: (~env: env, genericDist) => result<bool, error> let isNormalized: (~env: GenericDist.env, genericDist) => result<bool, error>
module LogScore: { module LogScore: {
@genType @genType
let distEstimateDistAnswer: (~env: env, genericDist, genericDist) => result<float, error> let distEstimateDistAnswer: (
@genType ~env: GenericDist.env,
let distEstimateDistAnswerWithPrior: (
~env: env,
genericDist, genericDist,
genericDist, genericDist,
DistributionTypes.DistributionOperation.genericDistOrScalar,
) => result<float, error> ) => result<float, error>
@genType @genType
let distEstimateScalarAnswer: (~env: env, genericDist, float) => result<float, error> let distEstimateDistAnswerWithPrior: (
~env: GenericDist.env,
genericDist,
genericDist,
genericDist,
) => result<float, error>
@genType @genType
let distEstimateScalarAnswerWithPrior: ( let distEstimateScalarAnswer: (
~env: env, ~env: GenericDist.env,
genericDist, genericDist,
float, float,
DistributionTypes.DistributionOperation.genericDistOrScalar, ) => result<float, error>
@genType
let distEstimateScalarAnswerWithPrior: (
~env: GenericDist.env,
genericDist,
float,
genericDist,
) => result<float, error> ) => result<float, error>
} }
@genType @genType
let toPointSet: (~env: env, genericDist) => result<genericDist, error> let toPointSet: (~env: GenericDist.env, genericDist) => result<genericDist, error>
@genType @genType
let toSampleSet: (~env: env, genericDist, int) => result<genericDist, error> let toSampleSet: (~env: GenericDist.env, genericDist, int) => result<genericDist, error>
@genType @genType
let fromSamples: (~env: env, SampleSetDist.t) => result<genericDist, error> let fromSamples: (~env: GenericDist.env, SampleSetDist.t) => result<genericDist, error>
@genType @genType
let truncate: (~env: env, genericDist, option<float>, option<float>) => result<genericDist, error> let truncate: (
~env: GenericDist.env,
genericDist,
option<float>,
option<float>,
) => result<genericDist, error>
@genType @genType
let inspect: (~env: env, genericDist) => result<genericDist, error> let inspect: (~env: GenericDist.env, genericDist) => result<genericDist, error>
@genType @genType
let toString: (~env: env, genericDist) => result<string, error> let toString: (~env: GenericDist.env, genericDist) => result<string, error>
@genType @genType
let toSparkline: (~env: env, genericDist, int) => result<string, error> let toSparkline: (~env: GenericDist.env, genericDist, int) => result<string, error>
@genType @genType
let algebraicAdd: (~env: env, genericDist, genericDist) => result<genericDist, error> let algebraicAdd: (~env: GenericDist.env, genericDist, genericDist) => result<genericDist, error>
@genType @genType
let algebraicMultiply: (~env: env, genericDist, genericDist) => result<genericDist, error> let algebraicMultiply: (
~env: GenericDist.env,
genericDist,
genericDist,
) => result<genericDist, error>
@genType @genType
let algebraicDivide: (~env: env, genericDist, genericDist) => result<genericDist, error> let algebraicDivide: (
~env: GenericDist.env,
genericDist,
genericDist,
) => result<genericDist, error>
@genType @genType
let algebraicSubtract: (~env: env, genericDist, genericDist) => result<genericDist, error> let algebraicSubtract: (
~env: GenericDist.env,
genericDist,
genericDist,
) => result<genericDist, error>
@genType @genType
let algebraicLogarithm: (~env: env, genericDist, genericDist) => result<genericDist, error> let algebraicLogarithm: (
~env: GenericDist.env,
genericDist,
genericDist,
) => result<genericDist, error>
@genType @genType
let algebraicPower: (~env: env, genericDist, genericDist) => result<genericDist, error> let algebraicPower: (
~env: GenericDist.env,
genericDist,
genericDist,
) => result<genericDist, error>
@genType @genType
let scaleLogarithm: (~env: env, genericDist, float) => result<genericDist, error> let scaleLogarithm: (~env: GenericDist.env, genericDist, float) => result<genericDist, error>
@genType @genType
let scaleMultiply: (~env: env, genericDist, float) => result<genericDist, error> let scaleMultiply: (~env: GenericDist.env, genericDist, float) => result<genericDist, error>
@genType @genType
let scalePower: (~env: env, genericDist, float) => result<genericDist, error> let scalePower: (~env: GenericDist.env, genericDist, float) => result<genericDist, error>
@genType @genType
let pointwiseAdd: (~env: env, genericDist, genericDist) => result<genericDist, error> let pointwiseAdd: (~env: GenericDist.env, genericDist, genericDist) => result<genericDist, error>
@genType @genType
let pointwiseMultiply: (~env: env, genericDist, genericDist) => result<genericDist, error> let pointwiseMultiply: (
~env: GenericDist.env,
genericDist,
genericDist,
) => result<genericDist, error>
@genType @genType
let pointwiseDivide: (~env: env, genericDist, genericDist) => result<genericDist, error> let pointwiseDivide: (
~env: GenericDist.env,
genericDist,
genericDist,
) => result<genericDist, error>
@genType @genType
let pointwiseSubtract: (~env: env, genericDist, genericDist) => result<genericDist, error> let pointwiseSubtract: (
~env: GenericDist.env,
genericDist,
genericDist,
) => result<genericDist, error>
@genType @genType
let pointwiseLogarithm: (~env: env, genericDist, genericDist) => result<genericDist, error> let pointwiseLogarithm: (
~env: GenericDist.env,
genericDist,
genericDist,
) => result<genericDist, error>
@genType @genType
let pointwisePower: (~env: env, genericDist, genericDist) => result<genericDist, error> let pointwisePower: (
~env: GenericDist.env,
genericDist,
genericDist,
) => result<genericDist, error>
} }

View File

@ -100,7 +100,7 @@ module DistributionOperation = {
type genericDistOrScalar = Score_Dist(genericDist) | Score_Scalar(float) type genericDistOrScalar = Score_Dist(genericDist) | Score_Scalar(float)
type toScore = LogScore(genericDistOrScalar, option<genericDistOrScalar>) type toScore = LogScore(genericDistOrScalar, option<genericDist>)
type fromFloat = [ type fromFloat = [
| #ToFloat(toFloat) | #ToFloat(toFloat)

View File

@ -6,6 +6,11 @@ type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error> type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error> type pointwiseAddFn = (t, t) => result<t, error>
type env = {
sampleCount: int,
xyPointLength: int,
}
let isPointSet = (t: t) => let isPointSet = (t: t) =>
switch t { switch t {
| PointSet(_) => true | PointSet(_) => true
@ -133,36 +138,33 @@ let toPointSet = (
module Score = { module Score = {
type genericDistOrScalar = DistributionTypes.DistributionOperation.genericDistOrScalar type genericDistOrScalar = DistributionTypes.DistributionOperation.genericDistOrScalar
type pointSet_ScoreDistOrScalar = PSDist(PointSetTypes.pointSetDist) | PSScalar(float)
let argsMake = ( let argsMake = (~esti: t, ~answ: genericDistOrScalar, ~prior: option<t>, ~env: env): result<
~esti: genericDistOrScalar, PointSetDist_Scoring.scoreArgs,
~answ: genericDistOrScalar, error,
~prior: option<genericDistOrScalar>, > => {
): result<PointSetDist_Scoring.scoreArgs, error> => {
let toPointSetFn = t => let toPointSetFn = t =>
toPointSet( toPointSet(
t, t,
~xyPointLength=MagicNumbers.Environment.defaultXYPointLength, ~xyPointLength=env.xyPointLength,
~sampleCount=MagicNumbers.Environment.defaultSampleCount, ~sampleCount=env.sampleCount,
~xSelection=#ByWeight, ~xSelection=#ByWeight,
(), (),
) )
let prior': option<result<pointSet_ScoreDistOrScalar, error>> = switch prior { let prior': option<result<PointSetTypes.pointSetDist, error>> = switch prior {
| None => None | None => None
| Some(Score_Dist(d)) => toPointSetFn(d)->E.R.bind(x => x->PSDist->Ok)->Some | Some(d) => toPointSetFn(d)->Some
| Some(Score_Scalar(s)) => s->PSScalar->Ok->Some
} }
let twoDists = (~toPointSetFn, esti': t, answ': t): result< let twoDists = (~toPointSetFn, esti': t, answ': t): result<
(PointSetTypes.pointSetDist, PointSetTypes.pointSetDist), (PointSetTypes.pointSetDist, PointSetTypes.pointSetDist),
error, error,
> => E.R.merge(toPointSetFn(esti'), toPointSetFn(answ')) > => E.R.merge(toPointSetFn(esti'), toPointSetFn(answ'))
switch (esti, answ, prior') { switch (esti, answ, prior') {
| (Score_Dist(esti'), Score_Dist(answ'), None) => | (esti', Score_Dist(answ'), None) =>
twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) => twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) =>
{estimate: esti'', answer: answ'', prior: None}->PointSetDist_Scoring.DistAnswer {estimate: esti'', answer: answ'', prior: None}->PointSetDist_Scoring.DistAnswer
) )
| (Score_Dist(esti'), Score_Dist(answ'), Some(Ok(PSDist(prior'')))) => | (esti', Score_Dist(answ'), Some(Ok(prior''))) =>
twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) => twoDists(~toPointSetFn, esti', answ')->E.R2.fmap(((esti'', answ'')) =>
{ {
estimate: esti'', estimate: esti'',
@ -170,8 +172,7 @@ module Score = {
prior: Some(prior''), prior: Some(prior''),
}->PointSetDist_Scoring.DistAnswer }->PointSetDist_Scoring.DistAnswer
) )
| (Score_Dist(_), _, Some(Ok(PSScalar(_)))) => DistributionTypes.Unreachable->Error | (esti', Score_Scalar(answ'), None) =>
| (Score_Dist(esti'), Score_Scalar(answ'), None) =>
toPointSetFn(esti')->E.R2.fmap(esti'' => toPointSetFn(esti')->E.R2.fmap(esti'' =>
{ {
estimate: esti'', estimate: esti'',
@ -179,7 +180,7 @@ module Score = {
prior: None, prior: None,
}->PointSetDist_Scoring.ScalarAnswer }->PointSetDist_Scoring.ScalarAnswer
) )
| (Score_Dist(esti'), Score_Scalar(answ'), Some(Ok(PSDist(prior'')))) => | (esti', Score_Scalar(answ'), Some(Ok(prior''))) =>
toPointSetFn(esti')->E.R2.fmap(esti'' => toPointSetFn(esti')->E.R2.fmap(esti'' =>
{ {
estimate: esti'', estimate: esti'',
@ -187,20 +188,17 @@ module Score = {
prior: Some(prior''), prior: Some(prior''),
}->PointSetDist_Scoring.ScalarAnswer }->PointSetDist_Scoring.ScalarAnswer
) )
| (Score_Scalar(_), Score_Dist(_), None) => NotYetImplemented->Error
| (Score_Scalar(_), Score_Dist(_), Some(Ok(PSScalar(_)))) => NotYetImplemented->Error
| (Score_Scalar(_), _, Some(Ok(PSDist(_)))) => DistributionTypes.Unreachable->Error
| (Score_Scalar(_), Score_Scalar(_), _) => NotYetImplemented->Error
| (_, _, Some(Error(err))) => err->Error | (_, _, Some(Error(err))) => err->Error
} }
} }
let logScore = ( let logScore = (
~estimate: genericDistOrScalar, ~estimate: t,
~answer: genericDistOrScalar, ~answer: genericDistOrScalar,
~prior: option<genericDistOrScalar>, ~prior: option<t>,
~env: env,
): result<float, error> => ): result<float, error> =>
argsMake(~esti=estimate, ~answ=answer, ~prior)->E.R.bind(x => argsMake(~esti=estimate, ~answ=answer, ~prior, ~env)->E.R.bind(x =>
x->PointSetDist.logScore->E.R2.errMap(y => DistributionTypes.OperationError(y)) x->PointSetDist.logScore->E.R2.errMap(y => DistributionTypes.OperationError(y))
) )
} }

View File

@ -5,6 +5,9 @@ type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error> type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error> type pointwiseAddFn = (t, t) => result<t, error>
@genType
type env = {sampleCount: int, xyPointLength: int}
let sampleN: (t, int) => array<float> let sampleN: (t, int) => array<float>
let sample: t => float let sample: t => float
@ -26,9 +29,10 @@ let toFloatOperation: (
module Score: { module Score: {
let logScore: ( let logScore: (
~estimate: DistributionTypes.DistributionOperation.genericDistOrScalar, ~estimate: t,
~answer: DistributionTypes.DistributionOperation.genericDistOrScalar, ~answer: DistributionTypes.DistributionOperation.genericDistOrScalar,
~prior: option<DistributionTypes.DistributionOperation.genericDistOrScalar>, ~prior: option<t>,
~env: env
) => result<float, error> ) => result<float, error>
} }

View File

@ -19,7 +19,7 @@ module WithDistAnswer = {
float, float,
Operation.Error.t, Operation.Error.t,
> => > =>
// We decided that negative infinity, not an error at answerElement = 0.0, is a desirable value. // We decided that 0.0, not an error at answerElement = 0.0, is a desirable value.
if answerElement == 0.0 { if answerElement == 0.0 {
Ok(0.0) Ok(0.0)
} else if estimateElement == 0.0 { } else if estimateElement == 0.0 {

View File

@ -8,6 +8,7 @@ type rec frType =
| FRTypeNumber | FRTypeNumber
| FRTypeNumeric | FRTypeNumeric
| FRTypeDistOrNumber | FRTypeDistOrNumber
| FRTypeDist
| FRTypeLambda | FRTypeLambda
| FRTypeRecord(frTypeRecord) | FRTypeRecord(frTypeRecord)
| FRTypeDict(frType) | FRTypeDict(frType)
@ -41,7 +42,7 @@ and frValueDistOrNumber = FRValueNumber(float) | FRValueDist(DistributionTypes.g
type fnDefinition = { type fnDefinition = {
name: string, name: string,
inputs: array<frType>, inputs: array<frType>,
run: (array<frValue>, DistributionOperation.env) => result<internalExpressionValue, string>, run: (array<frValue>, GenericDist.env) => result<internalExpressionValue, string>,
} }
type function = { type function = {
@ -60,6 +61,7 @@ module FRType = {
switch t { switch t {
| FRTypeNumber => "number" | FRTypeNumber => "number"
| FRTypeNumeric => "numeric" | FRTypeNumeric => "numeric"
| FRTypeDist => "distribution"
| FRTypeDistOrNumber => "distribution|number" | FRTypeDistOrNumber => "distribution|number"
| FRTypeRecord(r) => { | FRTypeRecord(r) => {
let input = ((name, frType): frTypeRecordParam) => `${name}: ${toString(frType)}` let input = ((name, frType): frTypeRecordParam) => `${name}: ${toString(frType)}`
@ -98,6 +100,7 @@ module FRType = {
| (FRTypeDistOrNumber, IEvDistribution(Symbolic(#Float(f)))) => | (FRTypeDistOrNumber, IEvDistribution(Symbolic(#Float(f)))) =>
Some(FRValueDistOrNumber(FRValueNumber(f))) Some(FRValueDistOrNumber(FRValueNumber(f)))
| (FRTypeDistOrNumber, IEvDistribution(f)) => Some(FRValueDistOrNumber(FRValueDist(f))) | (FRTypeDistOrNumber, IEvDistribution(f)) => Some(FRValueDistOrNumber(FRValueDist(f)))
| (FRTypeDist, IEvDistribution(f)) => Some(FRValueDist(f))
| (FRTypeNumeric, IEvNumber(f)) => Some(FRValueNumber(f)) | (FRTypeNumeric, IEvNumber(f)) => Some(FRValueNumber(f))
| (FRTypeNumeric, IEvDistribution(Symbolic(#Float(f)))) => Some(FRValueNumber(f)) | (FRTypeNumeric, IEvDistribution(Symbolic(#Float(f)))) => Some(FRValueNumber(f))
| (FRTypeLambda, IEvLambda(f)) => Some(FRValueLambda(f)) | (FRTypeLambda, IEvLambda(f)) => Some(FRValueLambda(f))
@ -319,7 +322,7 @@ module FnDefinition = {
t.name ++ `(${inputs})` t.name ++ `(${inputs})`
} }
let run = (t: t, args: array<internalExpressionValue>, env: DistributionOperation.env) => { let run = (t: t, args: array<internalExpressionValue>, env: GenericDist.env) => {
let argValues = FRType.matchWithExpressionValueArray(t.inputs, args) let argValues = FRType.matchWithExpressionValueArray(t.inputs, args)
switch argValues { switch argValues {
| Some(values) => t.run(values, env) | Some(values) => t.run(values, env)
@ -374,7 +377,7 @@ module Registry = {
~registry: registry, ~registry: registry,
~fnName: string, ~fnName: string,
~args: array<internalExpressionValue>, ~args: array<internalExpressionValue>,
~env: DistributionOperation.env, ~env: GenericDist.env,
) => { ) => {
let matchToDef = m => Matcher.Registry.matchToDef(registry, m) let matchToDef = m => Matcher.Registry.matchToDef(registry, m)
//Js.log(toSimple(registry)) //Js.log(toSimple(registry))

View File

@ -27,6 +27,12 @@ module Prepare = {
| _ => Error(impossibleError) | _ => Error(impossibleError)
} }
let threeArgs = (inputs: ts): result<ts, err> =>
switch inputs {
| [FRValueRecord([(_, n1), (_, n2), (_, n3)])] => Ok([n1, n2, n3])
| _ => Error(impossibleError)
}
let toArgs = (inputs: ts): result<ts, err> => let toArgs = (inputs: ts): result<ts, err> =>
switch inputs { switch inputs {
| [FRValueRecord(args)] => args->E.A2.fmap(((_, b)) => b)->Ok | [FRValueRecord(args)] => args->E.A2.fmap(((_, b)) => b)->Ok
@ -57,6 +63,13 @@ module Prepare = {
} }
} }
let twoDist = (values: ts): result<(DistributionTypes.genericDist, DistributionTypes.genericDist), err> => {
switch values {
| [FRValueDist(a1), FRValueDist(a2)] => Ok(a1, a2)
| _ => Error(impossibleError)
}
}
let twoNumbers = (values: ts): result<(float, float), err> => { let twoNumbers = (values: ts): result<(float, float), err> => {
switch values { switch values {
| [FRValueNumber(a1), FRValueNumber(a2)] => Ok(a1, a2) | [FRValueNumber(a1), FRValueNumber(a2)] => Ok(a1, a2)
@ -81,6 +94,9 @@ module Prepare = {
module Record = { module Record = {
let twoDistOrNumber = (values: ts): result<(frValueDistOrNumber, frValueDistOrNumber), err> => let twoDistOrNumber = (values: ts): result<(frValueDistOrNumber, frValueDistOrNumber), err> =>
values->ToValueArray.Record.twoArgs->E.R.bind(twoDistOrNumber) values->ToValueArray.Record.twoArgs->E.R.bind(twoDistOrNumber)
let twoDist = (values: ts): result<(DistributionTypes.genericDist, DistributionTypes.genericDist), err> =>
values->ToValueArray.Record.twoArgs->E.R.bind(twoDist)
} }
} }
@ -128,7 +144,7 @@ module Prepare = {
module Process = { module Process = {
module DistOrNumberToDist = { module DistOrNumberToDist = {
module Helpers = { module Helpers = {
let toSampleSet = (r, env: DistributionOperation.env) => let toSampleSet = (r, env: GenericDist.env) =>
GenericDist.toSampleSetDist(r, env.sampleCount) GenericDist.toSampleSetDist(r, env.sampleCount)
let mapFnResult = r => let mapFnResult = r =>
@ -166,7 +182,7 @@ module Process = {
let oneValue = ( let oneValue = (
~fn: float => result<DistributionTypes.genericDist, string>, ~fn: float => result<DistributionTypes.genericDist, string>,
~value: frValueDistOrNumber, ~value: frValueDistOrNumber,
~env: DistributionOperation.env, ~env: GenericDist.env,
): result<DistributionTypes.genericDist, string> => { ): result<DistributionTypes.genericDist, string> => {
switch value { switch value {
| FRValueNumber(a1) => fn(a1) | FRValueNumber(a1) => fn(a1)
@ -179,7 +195,7 @@ module Process = {
let twoValues = ( let twoValues = (
~fn: ((float, float)) => result<DistributionTypes.genericDist, string>, ~fn: ((float, float)) => result<DistributionTypes.genericDist, string>,
~values: (frValueDistOrNumber, frValueDistOrNumber), ~values: (frValueDistOrNumber, frValueDistOrNumber),
~env: DistributionOperation.env, ~env: GenericDist.env,
): result<DistributionTypes.genericDist, string> => { ): result<DistributionTypes.genericDist, string> => {
switch values { switch values {
| (FRValueNumber(a1), FRValueNumber(a2)) => fn((a1, a2)) | (FRValueNumber(a1), FRValueNumber(a2)) => fn((a1, a2))

View File

@ -49,7 +49,7 @@ let inputsTodist = (inputs: array<FunctionRegistry_Core.frValue>, makeDist) => {
expressionValue expressionValue
} }
let registry = [ let registryStart = [
Function.make( Function.make(
~name="toContinuousPointSet", ~name="toContinuousPointSet",
~definitions=[ ~definitions=[
@ -510,3 +510,67 @@ to(5,10)
(), (),
), ),
] ]
let runScoring = (estimate, answer, prior, env) => {
GenericDist.Score.logScore(~estimate, ~answer, ~prior, ~env)
->E.R2.fmap(FunctionRegistry_Helpers.Wrappers.evNumber)
->E.R2.errMap(DistributionTypes.Error.toString)
}
let scoreFunctions = [
Function.make(
~name="Score",
~definitions=[
FnDefinition.make(
~name="logScore",
~inputs=[
FRTypeRecord([
("estimate", FRTypeDist),
("answer", FRTypeDistOrNumber),
("prior", FRTypeDist),
]),
],
~run=(inputs, env) => {
switch FunctionRegistry_Helpers.Prepare.ToValueArray.Record.threeArgs(inputs) {
| Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueDist(d)), FRValueDist(prior)]) =>
runScoring(estimate, Score_Dist(d), Some(prior), env)
| Ok([
FRValueDist(estimate),
FRValueDistOrNumber(FRValueNumber(d)),
FRValueDist(prior),
]) =>
runScoring(estimate, Score_Scalar(d), Some(prior), env)
| Error(e) => Error(e)
| _ => Error(FunctionRegistry_Helpers.impossibleError)
}
},
),
FnDefinition.make(
~name="logScore",
~inputs=[FRTypeRecord([("estimate", FRTypeDist), ("answer", FRTypeDistOrNumber)])],
~run=(inputs, env) => {
switch FunctionRegistry_Helpers.Prepare.ToValueArray.Record.twoArgs(inputs) {
| Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueDist(d))]) =>
runScoring(estimate, Score_Dist(d), None, env)
| Ok([FRValueDist(estimate), FRValueDistOrNumber(FRValueNumber(d))]) =>
runScoring(estimate, Score_Scalar(d), None, env)
| Error(e) => Error(e)
| _ => Error(FunctionRegistry_Helpers.impossibleError)
}
},
),
FnDefinition.make(~name="klDivergence", ~inputs=[FRTypeDist, FRTypeDist], ~run=(
inputs,
env,
) => {
switch inputs {
| [FRValueDist(estimate), FRValueDist(d)] => runScoring(estimate, Score_Dist(d), None, env)
| _ => Error(FunctionRegistry_Helpers.impossibleError)
}
}),
],
(),
),
]
let registry = E.A.append(registryStart, scoreFunctions)

View File

@ -1,7 +1,7 @@
module IEV = ReducerInterface_InternalExpressionValue module IEV = ReducerInterface_InternalExpressionValue
type internalExpressionValue = IEV.t type internalExpressionValue = IEV.t
let dispatch = (call: IEV.functionCall, _: DistributionOperation.env): option< let dispatch = (call: IEV.functionCall, _: GenericDist.env): option<
result<internalExpressionValue, QuriSquiggleLang.Reducer_ErrorValue.errorValue>, result<internalExpressionValue, QuriSquiggleLang.Reducer_ErrorValue.errorValue>,
> => { > => {
switch call { switch call {

View File

@ -1,7 +1,7 @@
module IEV = ReducerInterface_InternalExpressionValue module IEV = ReducerInterface_InternalExpressionValue
type internalExpressionValue = IEV.t type internalExpressionValue = IEV.t
let dispatch = (call: IEV.functionCall, _: DistributionOperation.env): option< let dispatch = (call: IEV.functionCall, _: GenericDist.env): option<
result<internalExpressionValue, QuriSquiggleLang.Reducer_ErrorValue.errorValue>, result<internalExpressionValue, QuriSquiggleLang.Reducer_ErrorValue.errorValue>,
> => { > => {
switch call { switch call {

View File

@ -86,7 +86,7 @@ let toStringResult = x =>
} }
@genType @genType
type environment = DistributionOperation.env type environment = GenericDist.env
@genType @genType
let defaultEnvironment: environment = DistributionOperation.defaultEnv let defaultEnvironment: environment = DistributionOperation.defaultEnv

View File

@ -32,7 +32,7 @@ module Helpers = {
let toFloatFn = ( let toFloatFn = (
fnCall: DistributionTypes.DistributionOperation.toFloat, fnCall: DistributionTypes.DistributionOperation.toFloat,
dist: DistributionTypes.genericDist, dist: DistributionTypes.genericDist,
~env: DistributionOperation.env, ~env: GenericDist.env,
) => { ) => {
FromDist(#ToFloat(fnCall), dist)->DistributionOperation.run(~env)->Some FromDist(#ToFloat(fnCall), dist)->DistributionOperation.run(~env)->Some
} }
@ -40,7 +40,7 @@ module Helpers = {
let toStringFn = ( let toStringFn = (
fnCall: DistributionTypes.DistributionOperation.toString, fnCall: DistributionTypes.DistributionOperation.toString,
dist: DistributionTypes.genericDist, dist: DistributionTypes.genericDist,
~env: DistributionOperation.env, ~env: GenericDist.env,
) => { ) => {
FromDist(#ToString(fnCall), dist)->DistributionOperation.run(~env)->Some FromDist(#ToString(fnCall), dist)->DistributionOperation.run(~env)->Some
} }
@ -48,7 +48,7 @@ module Helpers = {
let toBoolFn = ( let toBoolFn = (
fnCall: DistributionTypes.DistributionOperation.toBool, fnCall: DistributionTypes.DistributionOperation.toBool,
dist: DistributionTypes.genericDist, dist: DistributionTypes.genericDist,
~env: DistributionOperation.env, ~env: GenericDist.env,
) => { ) => {
FromDist(#ToBool(fnCall), dist)->DistributionOperation.run(~env)->Some FromDist(#ToBool(fnCall), dist)->DistributionOperation.run(~env)->Some
} }
@ -56,12 +56,12 @@ module Helpers = {
let toDistFn = ( let toDistFn = (
fnCall: DistributionTypes.DistributionOperation.toDist, fnCall: DistributionTypes.DistributionOperation.toDist,
dist, dist,
~env: DistributionOperation.env, ~env: GenericDist.env,
) => { ) => {
FromDist(#ToDist(fnCall), dist)->DistributionOperation.run(~env)->Some FromDist(#ToDist(fnCall), dist)->DistributionOperation.run(~env)->Some
} }
let twoDiststoDistFn = (direction, arithmetic, dist1, dist2, ~env: DistributionOperation.env) => { let twoDiststoDistFn = (direction, arithmetic, dist1, dist2, ~env: GenericDist.env) => {
FromDist( FromDist(
#ToDistCombination(direction, arithmeticMap(arithmetic), #Dist(dist2)), #ToDistCombination(direction, arithmeticMap(arithmetic), #Dist(dist2)),
dist1, dist1,
@ -97,7 +97,7 @@ module Helpers = {
let mixtureWithGivenWeights = ( let mixtureWithGivenWeights = (
distributions: array<DistributionTypes.genericDist>, distributions: array<DistributionTypes.genericDist>,
weights: array<float>, weights: array<float>,
~env: DistributionOperation.env, ~env: GenericDist.env,
): DistributionOperation.outputType => ): DistributionOperation.outputType =>
E.A.length(distributions) == E.A.length(weights) E.A.length(distributions) == E.A.length(weights)
? Mixture(Belt.Array.zip(distributions, weights))->DistributionOperation.run(~env) ? Mixture(Belt.Array.zip(distributions, weights))->DistributionOperation.run(~env)
@ -107,7 +107,7 @@ module Helpers = {
let mixtureWithDefaultWeights = ( let mixtureWithDefaultWeights = (
distributions: array<DistributionTypes.genericDist>, distributions: array<DistributionTypes.genericDist>,
~env: DistributionOperation.env, ~env: GenericDist.env,
): DistributionOperation.outputType => { ): DistributionOperation.outputType => {
let length = E.A.length(distributions) let length = E.A.length(distributions)
let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length)) let weights = Belt.Array.make(length, 1.0 /. Belt.Int.toFloat(length))
@ -116,7 +116,7 @@ module Helpers = {
let mixture = ( let mixture = (
args: array<internalExpressionValue>, args: array<internalExpressionValue>,
~env: DistributionOperation.env, ~env: GenericDist.env,
): DistributionOperation.outputType => { ): DistributionOperation.outputType => {
let error = (err: string): DistributionOperation.outputType => let error = (err: string): DistributionOperation.outputType =>
err->DistributionTypes.ArgumentError->GenDistError err->DistributionTypes.ArgumentError->GenDistError
@ -173,7 +173,7 @@ module SymbolicConstructors = {
} }
} }
let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperation.env): option< let dispatchToGenericOutput = (call: IEV.functionCall, env: GenericDist.env): option<
DistributionOperation.outputType, DistributionOperation.outputType,
> => { > => {
let (fnName, args) = call let (fnName, args) = call
@ -213,70 +213,6 @@ let dispatchToGenericOutput = (call: IEV.functionCall, env: DistributionOperatio
~env, ~env,
)->Some )->Some
| ("normalize", [IEvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env) | ("normalize", [IEvDistribution(dist)]) => Helpers.toDistFn(Normalize, dist, ~env)
| ("klDivergence", [IEvDistribution(prediction), IEvDistribution(answer)]) =>
Some(
DistributionOperation.run(
FromDist(
#ToScore(LogScore(DistributionTypes.DistributionOperation.Score_Dist(answer), None)),
prediction,
),
~env,
),
)
| (
"klDivergence",
[IEvDistribution(prediction), IEvDistribution(answer), IEvDistribution(prior)],
) =>
Some(
DistributionOperation.run(
FromDist(
#ToScore(
LogScore(
DistributionTypes.DistributionOperation.Score_Dist(answer),
Some(DistributionTypes.DistributionOperation.Score_Dist(prior)),
),
),
prediction,
),
~env,
),
)
| (
"logScoreWithPointAnswer",
[IEvDistribution(prediction), IEvNumber(answer), IEvDistribution(prior)],
)
| (
"logScoreWithPointAnswer",
[
IEvDistribution(prediction),
IEvDistribution(Symbolic(#Float(answer))),
IEvDistribution(prior),
],
) =>
DistributionOperation.run(
FromDist(
#ToScore(
LogScore(
DistributionTypes.DistributionOperation.Score_Scalar(answer),
DistributionTypes.DistributionOperation.Score_Dist(prior)->Some,
),
),
prediction,
),
~env,
)->Some
| ("logScoreWithPointAnswer", [IEvDistribution(prediction), IEvNumber(answer)])
| (
"logScoreWithPointAnswer",
[IEvDistribution(prediction), IEvDistribution(Symbolic(#Float(answer)))],
) =>
DistributionOperation.run(
FromDist(
#ToScore(LogScore(DistributionTypes.DistributionOperation.Score_Scalar(answer), None)),
prediction,
),
~env,
)->Some
| ("isNormalized", [IEvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env) | ("isNormalized", [IEvDistribution(dist)]) => Helpers.toBoolFn(IsNormalized, dist, ~env)
| ("toPointSet", [IEvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env) | ("toPointSet", [IEvDistribution(dist)]) => Helpers.toDistFn(ToPointSet, dist, ~env)
| ("scaleLog", [IEvDistribution(dist)]) => | ("scaleLog", [IEvDistribution(dist)]) =>

View File

@ -24,7 +24,7 @@ module ScientificUnit = {
} }
} }
let dispatch = (call: IEV.functionCall, _: DistributionOperation.env): option< let dispatch = (call: IEV.functionCall, _: GenericDist.env): option<
result<internalExpressionValue, QuriSquiggleLang.Reducer_ErrorValue.errorValue>, result<internalExpressionValue, QuriSquiggleLang.Reducer_ErrorValue.errorValue>,
> => { > => {
switch call { switch call {

View File

@ -8,7 +8,7 @@ The below few seem to work fine. In the future there's definitely more work to d
*/ */
@genType @genType
type samplingParams = DistributionOperation.env type samplingParams = GenericDist.env
@genType @genType
type genericDist = DistributionTypes.genericDist type genericDist = DistributionTypes.genericDist