diff --git a/packages/squiggle-lang/src/rescript/sci.res b/packages/squiggle-lang/src/rescript/sci.res index dac5e5d6..15a57f44 100644 --- a/packages/squiggle-lang/src/rescript/sci.res +++ b/packages/squiggle-lang/src/rescript/sci.res @@ -1,19 +1,28 @@ -type symboliDist = SymbolicDistTypes.symbolicDist - type error = | NeedsPointSetConversion | InputsNeedPointSetConversion | NotYetImplemented + | ImpossiblePath | Other(string) type genericDist = [ - | #XYShape(PointSetTypes.pointSetDist) + | #PointSet(PointSetTypes.pointSetDist) | #SampleSet(array) - | #Symbolic(symboliDist) + | #Symbolic(SymbolicDistTypes.symbolicDist) +] + +type outputType = [ + | #Dist(genericDist) | #Error(error) | #Float(float) ] +let fromResult = (r: result): outputType => + switch r { + | Ok(o) => o + | Error(e) => #Error(e) + } + type direction = [ | #Algebraic | #Pointwise @@ -71,10 +80,11 @@ let genericParams = { } type wrapped = (genericDist, params) +type wrappedOutput = (outputType, params) let wrapWithParams = (g: genericDist, f: params): wrapped => (g, f) -let exampleDist: genericDist = #XYShape( +let exampleDist: genericDist = #PointSet( Discrete(Discrete.make(~integralSumCache=Some(1.0), {xs: [3.0], ys: [1.0]})), ) @@ -106,45 +116,79 @@ module AlgebraicCombination = { let sampleN = (n, genericDist) => { switch genericDist { - | #XYShape(r) => Ok(PointSetDist.sampleNRendered(n, r)) + | #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) | #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) - | #SampleSet(r) => Error(NotYetImplemented) - | #Error(r) => Error(r) - | _ => Error(NotYetImplemented) + | #SampleSet(_) => Error(NotYetImplemented) } } -let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => { - let (v, {sampleCount, xyPointLength} as extra) = wrapped - let newVal: genericDist = switch (fnName, v) { - | (#toFloat(n), #XYShape(r)) => #Float(PointSetDist.operate(n, r)) - | (#toFloat(n), #Symbolic(r)) => - switch SymbolicDist.T.operate(n, r) { - | Ok(float) => #Float(float) - | Error(e) => #Error(Other(e)) +let toFloat = ( + toPointSet: genericDist => result, + fnName, + value, +) => { + switch value { + | #Symbolic(r) if Belt.Result.isOk(SymbolicDist.T.operate(fnName, r)) => + switch SymbolicDist.T.operate(fnName, r) { + | Ok(float) => Ok(float) + | Error(_) => Error(ImpossiblePath) } - | (#toFloat(n), #SampleSet(_)) => #Error(NeedsPointSetConversion) - | (#toDist(#normalize), #XYShape(r)) => #XYShape(PointSetDist.T.normalize(r)) - | (#toDist(#normalize), #Symbolic(_)) => v - | (#toDist(#normalize), #SampleSet(_)) => v - | (#toDist(#toPointSet), #XYShape(_)) => v - | (#toDist(#toPointSet), #Symbolic(r)) => #XYShape(SymbolicDist.T.toPointSetDist(sampleCount, r)) - | (#toDist(#toPointSet), #SampleSet(r)) => { + | #PointSet(r) => Ok(PointSetDist.operate(fnName, r)) + | _ => + switch toPointSet(value) { + | Ok(r) => Ok(PointSetDist.operate(fnName, r)) + | Error(r) => Error(r) + } + } +} + +let distToPointSet = (sampleCount, dist: genericDist) => { + switch dist { + | #PointSet(pointSet) => Ok(pointSet) + | #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(sampleCount, r)) + | #SampleSet(r) => { let response = SampleSet.toPointSetDist( ~samples=r, ~samplingInputs=defaultSamplingInputs, (), ).pointSetDist switch response { - | Some(r) => #XYShape(r) - | None => #Error(Other("Failed to convert sample into shape")) + | Some(r) => Ok(r) + | None => Error(Other("Converting sampleSet to pointSet failed")) } } - | (#toDist(#toSampleSet(n)), r) => - switch sampleN(n, r) { - | Ok(r) => #SampleSet(r) - | Error(r) => #Error(r) + } +} + +let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrappedOutput => { + let (value, {sampleCount, xyPointLength} as extra) = wrapped + let reCall = (~value=value, ~extra=extra, ~fnName=fnName, ()) => { + applyFnInternal((value, extra), fnName) + } + let reCallUnwrapped = (~value=value, ~extra=extra, ~fnName=fnName, ()) => { + let (value, _) = applyFnInternal((value, extra), fnName) + value + } + let toPointSet = r => { + switch reCallUnwrapped(~value=r, ~fnName=#toDist(#toPointSet), ()) { + | #Dist(#PointSet(p)) => Ok(p) + | #Error(r) => Error(r) + | _ => Error(Other("Impossible error")) } + } + let toPointSetAndReCall = v => + toPointSet(v) |> E.R.fmap(r => reCallUnwrapped(~value=#PointSet(r), ())) + let newVal: outputType = switch (fnName, value) { + // | (#toFloat(n), v) => toFloat(toPointSet, v, n) + | (#toFloat(fnName), _) => + toFloat(toPointSet, fnName, value) |> E.R.fmap(r => #Float(r)) |> fromResult + | (#toDist(#normalize), #PointSet(r)) => #Dist(#PointSet(PointSetDist.T.normalize(r))) + | (#toDist(#normalize), #Symbolic(_)) => #Dist(value) + | (#toDist(#normalize), #SampleSet(_)) => #Dist(value) + | (#toDist(#toPointSet), _) => + value |> distToPointSet(sampleCount) |> E.R.fmap(r => #Dist(#PointSet(r))) |> fromResult + | (#toDist(#toSampleSet(n)), _) => + value |> sampleN(n) |> E.R.fmap(r => #Dist(#SampleSet(r))) |> fromResult | (#toDistCombination(#Algebraic, operation, p2), p1) => { // TODO: This could be more complex, to get possible simplification and similar. let dist1 = sampleN(sampleCount, p1) @@ -153,43 +197,44 @@ let rec applyFnInternal = (wrapped: wrapped, fnName: operation): wrapped => { Belt.Array.zip(d1, d2) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(operation, a, b)) }) switch samples { - | Ok(r) => #SampleSet(r) + | Ok(r) => #Dist(#SampleSet(r)) | Error(e) => #Error(e) } } | (#toDistCombination(#Pointwise, operation, p2), p1) => switch ( - applyFnInternal((p1, extra), #toDist(#toPointSet)), - applyFnInternal((p2, extra), #toDist(#toPointSet)), + toPointSet(p1), + toPointSet(p2) ) { - | ((#XYShape(p1), _), (#XYShape(p2), _)) => - #XYShape(PointSetDist.combinePointwise(combinationToFn(operation), p1, p2)) - | _ => #Error(Other("No Match or not supported")) + | (Ok(p1), Ok(p2)) => + // TODO: If the dist is symbolic, then it doesn't need to be converted into a pointSet + #Dist(#PointSet(PointSetDist.combinePointwise(combinationToFn(operation), p1, p2))) + | (_, _) => #Error(Other("No Match or not supported")) } | _ => #Error(Other("No Match or not supported")) } (newVal, {sampleCount: sampleCount, xyPointLength: xyPointLength}) } -let applyFn = (wrapped, fnName): wrapped => { - let (v, extra) as result = applyFnInternal(wrapped, fnName) - switch v { - | #Error(NeedsPointSetConversion) => { - let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet)) - applyFnInternal(convertedToPointSet, fnName) - } - | #Error(InputsNeedPointSetConversion) => { - let altDist = switch fnName { - | #toDistCombination(p1, p2, dist) => { - let (newDist, _) = applyFnInternal((dist, extra), #toDist(#toPointSet)) - applyFnInternal(wrapped, #toDistCombination(p1, p2, newDist)) - } - | _ => (#Error(Other("Not needed")), extra) - } - altDist - } - | _ => result - } -} +// let applyFn = (wrapped, fnName): wrapped => { +// let (v, extra) as result = applyFnInternal(wrapped, fnName) +// switch v { +// | #Error(NeedsPointSetConversion) => { +// let convertedToPointSet = applyFnInternal(wrapped, #toDist(#toPointSet)) +// applyFnInternal(convertedToPointSet, fnName) +// } +// | #Error(InputsNeedPointSetConversion) => { +// let altDist = switch fnName { +// | #toDistCombination(p1, p2, dist) => { +// let (newDist, _) = applyFnInternal((dist, extra), #toDist(#toPointSet)) +// applyFnInternal(wrapped, #toDistCombination(p1, p2, newDist)) +// } +// | _ => (#Error(Other("Not needed")), extra) +// } +// altDist +// } +// | _ => result +// } +// } -let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize)) +// let foo = exampleDist->wrapWithParams(genericParams)->applyFn(#toDist(#normalize))