diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res index 1df10240..dc50833e 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.res @@ -24,6 +24,7 @@ let isSymbolic = (t: t) => | _ => false } + let sampleN = (t: t, n) => switch t { | PointSet(r) => PointSetDist.sampleNRendered(n, r) @@ -31,6 +32,8 @@ let sampleN = (t: t, n) => | SampleSet(r) => SampleSetDist.sampleN(r, n) } +let sample = (t: t) => sampleN(t, 1) -> E.A.first |> E.O.toExn("Should not have happened") + let toSampleSetDist = (t: t, n) => SampleSetDist.make(sampleN(t, n))->E.R2.errMap(DistributionTypes.Error.sampleErrorToDistErr) diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi index 79fb54ab..fd9afa58 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist.resi @@ -6,6 +6,7 @@ type scaleMultiplyFn = (t, float) => result type pointwiseAddFn = (t, t) => result let sampleN: (t, int) => array +let sample: t => float let toSampleSetDist: (t, int) => Belt.Result.t diff --git a/packages/squiggle-lang/src/rescript/FunctionRegistry.res b/packages/squiggle-lang/src/rescript/FunctionRegistry.res index 7a74ed0c..df04cef9 100644 --- a/packages/squiggle-lang/src/rescript/FunctionRegistry.res +++ b/packages/squiggle-lang/src/rescript/FunctionRegistry.res @@ -34,6 +34,7 @@ let rec matchInput = (input: itype, r: expressionValue): option => switch (input, r) { | (I_Number, EvNumber(f)) => Some(Number(f)) | (I_DistOrNumber, EvNumber(f)) => Some(DistOrNumber(Number(f))) + | (I_DistOrNumber, EvDistribution(Symbolic(#Float(f)))) => Some(DistOrNumber(Number(f))) | (I_DistOrNumber, EvDistribution(f)) => Some(DistOrNumber(Dist(f))) | (I_Numeric, EvNumber(f)) => Some(Number(f)) | (I_Numeric, EvDistribution(Symbolic(#Float(f)))) => Some(Number(f)) @@ -236,17 +237,19 @@ module Registry = { } } +let impossibleError = "Wrong inputs / Logically impossible" + let twoNumberInputs = (inputs: array) => { switch inputs { | [Number(n1), Number(n2)] => Ok(n1, n2) - | _ => Error("Wrong inputs / Logically impossible") + | _ => Error(impossibleError) } } let twoNumberInputsRecord = (v1, v2, inputs: array) => switch inputs { | [Record([(name1, n1), (name2, n2)])] if name1 == v1 && name2 == v2 => twoNumberInputs([n1, n2]) - | _ => Error("Wrong inputs / Logically impossible") + | _ => Error(impossibleError) } let contain = r => ReducerInterface_ExpressionValue.EvDistribution(Symbolic(r)) @@ -258,24 +261,72 @@ let p5and95 = (p5, p95) => contain(SymbolicDist.Normal.from90PercentCI(p5, p95)) let convertTwoInputs = (inputs: array): result => twoNumberInputs(inputs)->E.R.bind(((mean, stdev)) => meanStdev(mean, stdev)) -// let twoDistOrStdev = (a1:distOrNumber, a2:distOrNumber, fn) => { -// switch (a1, a2) { -// | (Number(a1), Number(a2)) => fn(a1, a2) -// | (Dist(a1), Number(a2)) => toSampleSetDist(a1, 1000)->sampleMap(r => fn(r, a2) |> sample) -// | (Number(a1), Dist(a2)) => toSampleSetDist(a2, 1000)->sampleMap(r => fn(a1, r) |> sample) -// | (Dist(a1), Dist(a2)) => SampleSetDist.map2(a1, a2, (m, s) => fn(m, s) |> sample) -// } -// } +let twoDistOrStdev = (a1: value, a2: value) => { + switch (a1, a2) { + | (DistOrNumber(a1), DistOrNumber(a2)) => Ok(a1, a2) + | _ => Error(impossibleError) + } +} + +let distTwo = ( + ~fn: (float, float) => result, + a1: value, + a2: value, +) => { + let toSampleSet = r => GenericDist.toSampleSetDist(r, 1000) + let sampleSetToExpressionValue = ( + b: Belt.Result.t, + ) => + switch b { + | Ok(r) => Ok(ReducerInterface_ExpressionValue.EvDistribution(SampleSet(r))) + | Error(d) => Error(DistributionTypes.Error.toString(d)) + } + + let mapFnResult = r => + switch r { + | Ok(r) => Ok(GenericDist.sample(r)) + | Error(r) => Error(Operation.Other(r)) + } + + let singleVarSample = (a, fn) => { + let sampleSetResult = + toSampleSet(a) |> E.R2.bind(dist => + SampleSetDist.samplesMap( + ~fn=f => fn(f)->mapFnResult, + dist, + )->E.R2.errMap(r => DistributionTypes.SampleSetError(r)) + ) + sampleSetResult->sampleSetToExpressionValue + } + + switch (a1, a2) { + | (DistOrNumber(Number(a1)), DistOrNumber(Number(a2))) => + fn(a1, a2)->E.R2.fmap(r => ReducerInterface_ExpressionValue.EvDistribution(r)) + | (DistOrNumber(Dist(a1)), DistOrNumber(Number(a2))) => singleVarSample(a1, r => fn(r, a2)) + | (DistOrNumber(Number(a1)), DistOrNumber(Dist(a2))) => singleVarSample(a2, r => fn(a1, r)) + | (DistOrNumber(Dist(a1)), DistOrNumber(Dist(a2))) => { + let altFn = (a, b) => fn(a, b)->mapFnResult + let sampleSetResult = + E.R.merge(toSampleSet(a1), toSampleSet(a2)) + ->E.R2.errMap(DistributionTypes.Error.toString) + ->E.R.bind(((t1, t2)) => { + SampleSetDist.map2(~fn=altFn, ~t1, ~t2)->E.R2.errMap(Operation.Error.toString) + }) + ->E.R2.errMap(r => DistributionTypes.OtherError(r)) + sampleSetResult->sampleSetToExpressionValue + } + | _ => Error(impossibleError) + } +} let normal = Function.make( "Normal", [ - Function.makeDefinition("normal", [I_Numeric, I_Numeric], inputs => - twoNumberInputs(inputs)->E.R.bind(((mean, stdev)) => meanStdev(mean, stdev)) - ), - Function.makeDefinition("normal", [I_DistOrNumber, I_DistOrNumber], inputs => - twoNumberInputs(inputs)->E.R.bind(((mean, stdev)) => meanStdev(mean, stdev)) - ), + Function.makeDefinition("normal", [I_DistOrNumber, I_DistOrNumber], inputs => { + let combine = (a1: float, a2: float) => + SymbolicDist.Normal.make(a1, a2)->E.R2.fmap(r => DistributionTypes.Symbolic(r)) + distTwo(~fn=combine, inputs[0], inputs[1]) + }), Function.makeDefinition( "normal", [I_Record([("mean", I_Numeric), ("stdev", I_Numeric)])], diff --git a/packages/squiggle-lang/src/rescript/Utility/Operation.res b/packages/squiggle-lang/src/rescript/Utility/Operation.res index 7972b2fa..6476850d 100644 --- a/packages/squiggle-lang/src/rescript/Utility/Operation.res +++ b/packages/squiggle-lang/src/rescript/Utility/Operation.res @@ -58,6 +58,7 @@ type operationError = | SampleMapNeedsNtoNFunction | PdfInvalidError | NotYetImplemented // should be removed when `klDivergence` for mixed and discrete is implemented. + | Other(string) @genType module Error = { @@ -73,6 +74,7 @@ module Error = { | SampleMapNeedsNtoNFunction => "SampleMap needs a function that converts a number to a number" | PdfInvalidError => "This Pdf is invalid" | NotYetImplemented => "This pathway is not yet implemented" + | Other(t) => t } }