Touchups for FunctionRegistry distTwo

This commit is contained in:
Ozzie Gooen 2022-05-19 09:25:34 -04:00
parent 80a6c56efc
commit 50a5ef2498
4 changed files with 73 additions and 16 deletions

View File

@ -24,6 +24,7 @@ let isSymbolic = (t: t) =>
| _ => false
}
let sampleN = (t: t, n) =>
switch t {
| PointSet(r) => PointSetDist.sampleNRendered(n, r)
@ -31,6 +32,8 @@ let sampleN = (t: t, n) =>
| SampleSet(r) => SampleSetDist.sampleN(r, n)
}
let sample = (t: t) => sampleN(t, 1) -> E.A.first |> E.O.toExn("Should not have happened")
let toSampleSetDist = (t: t, n) =>
SampleSetDist.make(sampleN(t, n))->E.R2.errMap(DistributionTypes.Error.sampleErrorToDistErr)

View File

@ -6,6 +6,7 @@ type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error>
let sampleN: (t, int) => array<float>
let sample: t => float
let toSampleSetDist: (t, int) => Belt.Result.t<QuriSquiggleLang.SampleSetDist.t, error>

View File

@ -34,6 +34,7 @@ let rec matchInput = (input: itype, r: expressionValue): option<value> =>
switch (input, r) {
| (I_Number, EvNumber(f)) => Some(Number(f))
| (I_DistOrNumber, EvNumber(f)) => Some(DistOrNumber(Number(f)))
| (I_DistOrNumber, EvDistribution(Symbolic(#Float(f)))) => Some(DistOrNumber(Number(f)))
| (I_DistOrNumber, EvDistribution(f)) => Some(DistOrNumber(Dist(f)))
| (I_Numeric, EvNumber(f)) => Some(Number(f))
| (I_Numeric, EvDistribution(Symbolic(#Float(f)))) => Some(Number(f))
@ -236,17 +237,19 @@ module Registry = {
}
}
let impossibleError = "Wrong inputs / Logically impossible"
let twoNumberInputs = (inputs: array<value>) => {
switch inputs {
| [Number(n1), Number(n2)] => Ok(n1, n2)
| _ => Error("Wrong inputs / Logically impossible")
| _ => Error(impossibleError)
}
}
let twoNumberInputsRecord = (v1, v2, inputs: array<value>) =>
switch inputs {
| [Record([(name1, n1), (name2, n2)])] if name1 == v1 && name2 == v2 => twoNumberInputs([n1, n2])
| _ => Error("Wrong inputs / Logically impossible")
| _ => Error(impossibleError)
}
let contain = r => ReducerInterface_ExpressionValue.EvDistribution(Symbolic(r))
@ -258,24 +261,72 @@ let p5and95 = (p5, p95) => contain(SymbolicDist.Normal.from90PercentCI(p5, p95))
let convertTwoInputs = (inputs: array<value>): result<expressionValue, string> =>
twoNumberInputs(inputs)->E.R.bind(((mean, stdev)) => meanStdev(mean, stdev))
// let twoDistOrStdev = (a1:distOrNumber, a2:distOrNumber, fn) => {
// switch (a1, a2) {
// | (Number(a1), Number(a2)) => fn(a1, a2)
// | (Dist(a1), Number(a2)) => toSampleSetDist(a1, 1000)->sampleMap(r => fn(r, a2) |> sample)
// | (Number(a1), Dist(a2)) => toSampleSetDist(a2, 1000)->sampleMap(r => fn(a1, r) |> sample)
// | (Dist(a1), Dist(a2)) => SampleSetDist.map2(a1, a2, (m, s) => fn(m, s) |> sample)
// }
// }
let twoDistOrStdev = (a1: value, a2: value) => {
switch (a1, a2) {
| (DistOrNumber(a1), DistOrNumber(a2)) => Ok(a1, a2)
| _ => Error(impossibleError)
}
}
let distTwo = (
~fn: (float, float) => result<DistributionTypes.genericDist, string>,
a1: value,
a2: value,
) => {
let toSampleSet = r => GenericDist.toSampleSetDist(r, 1000)
let sampleSetToExpressionValue = (
b: Belt.Result.t<QuriSquiggleLang.SampleSetDist.t, QuriSquiggleLang.DistributionTypes.error>,
) =>
switch b {
| Ok(r) => Ok(ReducerInterface_ExpressionValue.EvDistribution(SampleSet(r)))
| Error(d) => Error(DistributionTypes.Error.toString(d))
}
let mapFnResult = r =>
switch r {
| Ok(r) => Ok(GenericDist.sample(r))
| Error(r) => Error(Operation.Other(r))
}
let singleVarSample = (a, fn) => {
let sampleSetResult =
toSampleSet(a) |> E.R2.bind(dist =>
SampleSetDist.samplesMap(
~fn=f => fn(f)->mapFnResult,
dist,
)->E.R2.errMap(r => DistributionTypes.SampleSetError(r))
)
sampleSetResult->sampleSetToExpressionValue
}
switch (a1, a2) {
| (DistOrNumber(Number(a1)), DistOrNumber(Number(a2))) =>
fn(a1, a2)->E.R2.fmap(r => ReducerInterface_ExpressionValue.EvDistribution(r))
| (DistOrNumber(Dist(a1)), DistOrNumber(Number(a2))) => singleVarSample(a1, r => fn(r, a2))
| (DistOrNumber(Number(a1)), DistOrNumber(Dist(a2))) => singleVarSample(a2, r => fn(a1, r))
| (DistOrNumber(Dist(a1)), DistOrNumber(Dist(a2))) => {
let altFn = (a, b) => fn(a, b)->mapFnResult
let sampleSetResult =
E.R.merge(toSampleSet(a1), toSampleSet(a2))
->E.R2.errMap(DistributionTypes.Error.toString)
->E.R.bind(((t1, t2)) => {
SampleSetDist.map2(~fn=altFn, ~t1, ~t2)->E.R2.errMap(Operation.Error.toString)
})
->E.R2.errMap(r => DistributionTypes.OtherError(r))
sampleSetResult->sampleSetToExpressionValue
}
| _ => Error(impossibleError)
}
}
let normal = Function.make(
"Normal",
[
Function.makeDefinition("normal", [I_Numeric, I_Numeric], inputs =>
twoNumberInputs(inputs)->E.R.bind(((mean, stdev)) => meanStdev(mean, stdev))
),
Function.makeDefinition("normal", [I_DistOrNumber, I_DistOrNumber], inputs =>
twoNumberInputs(inputs)->E.R.bind(((mean, stdev)) => meanStdev(mean, stdev))
),
Function.makeDefinition("normal", [I_DistOrNumber, I_DistOrNumber], inputs => {
let combine = (a1: float, a2: float) =>
SymbolicDist.Normal.make(a1, a2)->E.R2.fmap(r => DistributionTypes.Symbolic(r))
distTwo(~fn=combine, inputs[0], inputs[1])
}),
Function.makeDefinition(
"normal",
[I_Record([("mean", I_Numeric), ("stdev", I_Numeric)])],

View File

@ -58,6 +58,7 @@ type operationError =
| SampleMapNeedsNtoNFunction
| PdfInvalidError
| NotYetImplemented // should be removed when `klDivergence` for mixed and discrete is implemented.
| Other(string)
@genType
module Error = {
@ -73,6 +74,7 @@ module Error = {
| SampleMapNeedsNtoNFunction => "SampleMap needs a function that converts a number to a number"
| PdfInvalidError => "This Pdf is invalid"
| NotYetImplemented => "This pathway is not yet implemented"
| Other(t) => t
}
}