From 540d035b900db77b2a8db30da5bd8a6ec5bd6f04 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Wed, 23 Mar 2022 17:29:20 -0400 Subject: [PATCH] Refactored applyFnInternal --- packages/squiggle-lang/src/rescript/sci.res | 92 ++++++++++++++------- 1 file changed, 61 insertions(+), 31 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/sci.res b/packages/squiggle-lang/src/rescript/sci.res index c31c22f8..82b33e5f 100644 --- a/packages/squiggle-lang/src/rescript/sci.res +++ b/packages/squiggle-lang/src/rescript/sci.res @@ -1,13 +1,44 @@ type symboliDist = SymbolicDistTypes.symbolicDist +type error = + | NeedsPointSetConversion + | Other(string) + type genericDist = [ | #XYShape(PointSetTypes.pointSetDist) | #SampleSet(array) | #Symbolic(symboliDist) - | #Error(string) + | #Error(error) | #Float(float) ] +type combination = [ + | #Add + | #Multiply + | #Subtract + | #Divide + | #Exponentiate +] + +type toFloat = [ + | #Cdf(float) + | #Inv(float) + | #Mean + | #Pdf(float) + | #Sample +] + +type toDist = [ + | #normalize + | #toPointSet +] + +type operation = [ + | #toFloat(toFloat) + | #toDist(toDist) + | #toDistCombination(combination, genericDist) +] + type params = { sampleCount: int, xyPointLength: int, @@ -33,28 +64,22 @@ let defaultSamplingInputs: SamplingInputs.samplingInputs = { kernelWidth: None, } -let distToFloat = (wrapped: wrapped, fnName) => { +let applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => { let (v, extra) = wrapped - let newVal = switch (fnName, v) { - | (operation, #XYShape(r)) => #Float(PointSetDist.operate(operation, r)) - | (operation, #Symbolic(r)) => switch(SymbolicDist.T.operate(operation, r)){ - | Ok(r) => #SymbolicDist(r) - | Error(r) => #Error(r) - } - | _ => #Error("No Match") - } - (newVal, extra) -} - -let distToDist = (wrapped: wrapped, fnName): wrapped => { - let (v, extra) = wrapped - let newVal = switch (fnName, v) { - | (#normalize, #XYShape(r)) => #XYShape(PointSetDist.T.normalize(r)) - | (#normalize, #Symbolic(_)) => v - | (#normalize, #SampleSet(_)) => v - | (#toPointSet, #XYShape(_)) => v - | (#toPointSet, #Symbolic(r)) => #XYShape(SymbolicDist.T.toPointSetDist(1000, r)) - | (#toPointSet, #SampleSet(r)) => { + let newVal: genericDist = switch (fnName, v) { + | (#toFloat(n), #XYShape(r)) => #Float(PointSetDist.operate(n, r)) + | (#toFloat(n), #Symbolic(r)) => + switch SymbolicDist.T.operate(n, r) { + | Ok(float) => #Float(float) + | Error(e) => #Error(Other(e)) + } + | (#toFloat(n), #SampleSet(_)) => #Error(NeedsPointSetConversion) + | (#toDist(#normalize), #XYShape(r)) => #XYShape(PointSetDist.T.normalize(r)) + | (#toDist(#normalize), #Symbolic(_)) => v + | (#toDist(#normalize), #SampleSet(_)) => v + | (#toDist(#toPointSet), #XYShape(_)) => v + | (#toDist(#toPointSet), #Symbolic(r)) => #XYShape(SymbolicDist.T.toPointSetDist(1000, r)) + | (#toDist(#toPointSet), #SampleSet(r)) => { let response = SampleSet.toPointSetDist( ~samples=r, ~samplingInputs=defaultSamplingInputs, @@ -62,18 +87,23 @@ let distToDist = (wrapped: wrapped, fnName): wrapped => { ).pointSetDist switch response { | Some(r) => #XYShape(r) - | None => #Error("Failed to convert sample into shape") + | None => #Error(Other("Failed to convert sample into shape")) } } - | _ => #Error("No Match") + | _ => #Error(Other("No Match or not supported")) } (newVal, extra) } -// | (#truncateLeft(f), #XYContinuous(r)) => #XYContinuous(Continuous.T.truncate(Some(f), None, r)) -// | (#truncateRight(f), #XYContinuous(r)) => #XYContinuous(Continuous.T.truncate(None, Some(f), r)) -let foo = - exampleDist - ->wrapWithParams(genericParams) - ->distToDist(#truncateLeft(3.0)) - ->distToDist(#trunctateRight(5.0)) +let applyFn = (wrapped, fnName): wrapped => { + let (v, extra) as result = applyFnInternal(wrapped, fnName) + switch v { + | #Error(NeedsPointSetConversion) => { + let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet)) + applyFnInternal(convertedToPointSet, fnName) + } + | _ => result + } +} + +let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize))