diff --git a/packages/squiggle-lang/src/rescript/ProgramEvaluator.res b/packages/squiggle-lang/src/rescript/ProgramEvaluator.res index 6e2b118d..134daec3 100644 --- a/packages/squiggle-lang/src/rescript/ProgramEvaluator.res +++ b/packages/squiggle-lang/src/rescript/ProgramEvaluator.res @@ -64,7 +64,7 @@ module Internals = { } let makeOutputs = (graph, pointSetDist): outputs => {graph: graph, pointSetDist: pointSetDist} - let makeInputs = (inputs: Inputs.inputs): ASTTypes.AST.samplingInputs => { + let makeInputs = (inputs: Inputs.inputs): SamplingInputs.samplingInputs => { sampleCount: inputs.samplingInputs.sampleCount |> E.O.default(10000), outputXYPoints: inputs.samplingInputs.outputXYPoints |> E.O.default(10000), kernelWidth: inputs.samplingInputs.kernelWidth, @@ -74,7 +74,7 @@ module Internals = { let runNode = (inputs, node) => AST.toLeaf(makeInputs(inputs), inputs.environment, node) - let runProgram = (inputs: Inputs.inputs, p: ASTTypes.Program.program) => { + let runProgram = (inputs: Inputs.inputs, p: ASTTypes.AST.program) => { let ins = ref(inputs) p |> E.A.fmap(x => diff --git a/packages/squiggle-lang/src/rescript/interpreter/AST.res b/packages/squiggle-lang/src/rescript/interpreter/AST.res index a7f6619f..4b105f1d 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/AST.res +++ b/packages/squiggle-lang/src/rescript/interpreter/AST.res @@ -18,7 +18,7 @@ let toPointSetDist = (samplingInputs, environment, node: node) => | Error(e) => Error(e) } -let runFunction = (samplingInputs, environment, inputs, fn: PTypes.Function.t) => { +let runFunction = (samplingInputs, environment, inputs, fn: ASTTypes.AST.Function.t) => { let params = envs(samplingInputs, environment) - PTypes.Function.run(params, inputs, fn) + ASTTypes.AST.Function.run(params, inputs, fn) } diff --git a/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res b/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res index 54cee830..481bd382 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res +++ b/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res @@ -1,4 +1,3 @@ -open ASTTypes open ASTTypes.AST type t = node @@ -44,12 +43,17 @@ module AlgebraicCombination = { let combine = (evaluationParams, algebraicOp, t1: node, t2: node): result => E.R.merge( - PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1), - PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2), + ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1), + ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2), ) |> E.R.bind(_, ((a, b)) => switch choose(a, b) { | #Sampling => - PTypes.SamplingDistribution.combineShapesUsingSampling(evaluationParams, algebraicOp, a, b) + ASTTypes.AST.SamplingDistribution.combineShapesUsingSampling( + evaluationParams, + algebraicOp, + a, + b, + ) | #Analytical => combinationByRendering(evaluationParams, algebraicOp, a, b) } ) @@ -118,6 +122,12 @@ module PointwiseCombination = { } module Truncate = { + type simplificationResult = [ + | #Solution(ASTTypes.AST.node) + | #Error(string) + | #NoSolution + ] + let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult => switch (leftCutoff, rightCutoff, t) { | (None, None, t) => #Solution(t) @@ -132,7 +142,8 @@ module Truncate = { switch // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail // of a distribution we otherwise wouldn't get at all Node.ensureIsRendered(evaluationParams, t) { - | Ok(#RenderedDist(rs)) => Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs))) + | Ok(#RenderedDist(rs)) => + Ok(#RenderedDist(PointSetDist.T.truncate(leftCutoff, rightCutoff, rs))) | Error(e) => Error(e) | _ => Error("Could not truncate distribution.") } @@ -160,7 +171,7 @@ module Normalize = { switch t { | #RenderedDist(s) => Ok(#RenderedDist(PointSetDist.T.normalize(s))) | #SymbolicDist(_) => Ok(t) - | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) + | _ => ASTTypes.AST.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t) } } @@ -170,7 +181,7 @@ module FunctionCall = { let _runLocalFunction = (name, evaluationParams: evaluationParams, args) => Environment.getFunction(evaluationParams.environment, name) |> E.R.bind(_, ((argNames, fn)) => - PTypes.Function.run(evaluationParams, args, (argNames, fn)) + ASTTypes.AST.Function.run(evaluationParams, args, (argNames, fn)) ) let _runWithEvaluatedInputs = ( @@ -195,9 +206,13 @@ module Render = { switch t { | #Function(_) => Error("Cannot render a function") | #SymbolicDist(d) => - Ok(#RenderedDist(SymbolicDist.T.toPointSetDist(evaluationParams.samplingInputs.pointSetDistLength, d))) + Ok( + #RenderedDist( + SymbolicDist.T.toPointSetDist(evaluationParams.samplingInputs.pointSetDistLength, d), + ), + ) | #RenderedDist(_) as t => Ok(t) // already a rendered pointSetDist, we're done here - | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) + | _ => ASTTypes.AST.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t) } } @@ -207,10 +222,7 @@ module Render = { but most often it will produce a RenderedDist. This function is used mainly to turn a parse tree into a single RenderedDist that can then be displayed to the user. */ -let rec toLeaf = ( - evaluationParams: ASTTypes.AST.evaluationParams, - node: t, -): result => +let rec toLeaf = (evaluationParams: ASTTypes.AST.evaluationParams, node: t): result => switch node { // Leaf nodes just stay leaf nodes | #SymbolicDist(_) diff --git a/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res b/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res index 85b5ac12..8526e9ec 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res +++ b/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res @@ -15,43 +15,20 @@ module AST = { | #FunctionCall(string, array) ] - module Hash = { - type t<'a> = array<(string, 'a)> - let getByName = (t: t<'a>, name) => - E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r) - - let getByNameResult = (t: t<'a>, name) => - getByName(t, name) |> E.O.toResult(name ++ " expected and not found") - - let getByNames = (hash: t<'a>, names: array) => - names |> E.A.fmap(name => (name, getByName(hash, name))) - } - // Have nil as option - - type samplingInputs = { - sampleCount: int, - outputXYPoints: int, - kernelWidth: option, - pointSetDistLength: int, - } - - module SamplingInputs = { - type t = { - sampleCount: option, - outputXYPoints: option, - kernelWidth: option, - pointSetDistLength: option, - } - let withDefaults = (t: t): samplingInputs => { - sampleCount: t.sampleCount |> E.O.default(10000), - outputXYPoints: t.outputXYPoints |> E.O.default(10000), - kernelWidth: t.kernelWidth, - pointSetDistLength: t.pointSetDistLength |> E.O.default(10000), - } - } + type statement = [ + | #Assignment(string, node) + | #Expression(node) + ] + type program = array type environment = Belt.Map.String.t + type rec evaluationParams = { + samplingInputs: SamplingInputs.samplingInputs, + environment: environment, + evaluateNode: (evaluationParams, node) => Belt.Result.t, + } + module Environment = { type t = environment module MS = Belt.Map.String @@ -74,18 +51,6 @@ module AST = { } } - type rec evaluationParams = { - samplingInputs: samplingInputs, - environment: environment, - evaluateNode: (evaluationParams, node) => Belt.Result.t, - } - - let evaluateNode = (evaluationParams: evaluationParams) => - evaluationParams.evaluateNode(evaluationParams) - - let evaluateAndRetry = (evaluationParams, fn, node) => - node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams)) - module Node = { let getFloat = (node: node) => node |> ( @@ -97,6 +62,12 @@ module AST = { } ) + let evaluate = (evaluationParams: evaluationParams) => + evaluationParams.evaluateNode(evaluationParams) + + let evaluateAndRetry = (evaluationParams, fn, node) => + node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams)) + let rec toString: node => string = x => switch x { | #SymbolicDist(d) => SymbolicDist.T.toString(d) @@ -124,8 +95,7 @@ module AST = { "}") } - let render = (evaluationParams: evaluationParams, r) => - #Render(r) |> evaluateNode(evaluationParams) + let render = (evaluationParams: evaluationParams, r) => #Render(r) |> evaluate(evaluationParams) let ensureIsRendered = (params, t) => switch t { @@ -160,18 +130,143 @@ module AST = { let toFloat = (item: node): result => item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape") } -} -type simplificationResult = [ - | #Solution(AST.node) - | #Error(string) - | #NoSolution -] + module Function = { + type t = (array, node) + let fromNode: node => option = node => + switch node { + | #Function(r) => Some(r) + | _ => None + } + let argumentNames = ((a, _): t) => a + let internals = ((_, b): t) => b + let run = (evaluationParams: evaluationParams, args: array, t: t) => + if E.A.length(args) == E.A.length(argumentNames(t)) { + let newEnvironment = Belt.Array.zip(argumentNames(t), args) |> Environment.fromArray + let newEvaluationParams: evaluationParams = { + samplingInputs: evaluationParams.samplingInputs, + environment: Environment.mergeKeepSecond( + evaluationParams.environment, + newEnvironment, + ), + evaluateNode: evaluationParams.evaluateNode, + } + evaluationParams.evaluateNode(newEvaluationParams, internals(t)) + } else { + Error("Wrong number of variables") + } + } -module Program = { - type statement = [ - | #Assignment(string, AST.node) - | #Expression(AST.node) - ] - type program = array + module Primative = { + type t = [ + | #SymbolicDist(SymbolicDistTypes.symbolicDist) + | #RenderedDist(PointSetTypes.pointSetDist) + | #Function(array, node) + ] + + let isPrimative: node => bool = x => + switch x { + | #SymbolicDist(_) + | #RenderedDist(_) + | #Function(_) => true + | _ => false + } + + let fromNode: node => option = x => + switch x { + | #SymbolicDist(_) as n + | #RenderedDist(_) as n + | #Function(_) as n => + Some(n) + | _ => None + } + } + + module SamplingDistribution = { + type t = [ + | #SymbolicDist(SymbolicDistTypes.symbolicDist) + | #RenderedDist(PointSetTypes.pointSetDist) + ] + + let isSamplingDistribution: node => bool = x => + switch x { + | #SymbolicDist(_) => true + | #RenderedDist(_) => true + | _ => false + } + + let fromNode: node => result = x => + switch x { + | #SymbolicDist(n) => Ok(#SymbolicDist(n)) + | #RenderedDist(n) => Ok(#RenderedDist(n)) + | _ => Error("Not valid type") + } + + let renderIfIsNotSamplingDistribution = (params, t): result => + !isSamplingDistribution(t) + ? switch Node.render(params, t) { + | Ok(r) => Ok(r) + | Error(e) => Error(e) + } + : Ok(t) + + let map = (~renderedDistFn, ~symbolicDistFn, node: node) => + node |> ( + x => + switch x { + | #RenderedDist(r) => Some(renderedDistFn(r)) + | #SymbolicDist(s) => Some(symbolicDistFn(s)) + | _ => None + } + ) + + let sampleN = n => + map( + ~renderedDistFn=PointSetDist.sampleNRendered(n), + ~symbolicDistFn=SymbolicDist.T.sampleN(n), + ) + + let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => + switch (sampleN(n, t1), sampleN(n, t2)) { + | (Some(a), Some(b)) => + Some( + Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)), + ) + | _ => None + } + + let combineShapesUsingSampling = ( + evaluationParams: evaluationParams, + algebraicOp, + t1: node, + t2: node, + ) => { + let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1) + let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2) + E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => { + let samples = getCombinationSamples( + evaluationParams.samplingInputs.sampleCount, + algebraicOp, + a, + b, + ) + + let pointSetDist = + samples + |> E.O.fmap(r => + SampleSet.toPointSetDist( + ~samplingInputs=evaluationParams.samplingInputs, + ~samples=r, + (), + ) + ) + |> E.O.bind(_, r => r.pointSetDist) + |> E.O.toResult("No response") + pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r))) + }) + + // todo: This bottom part should probably be somewhere else. + // todo: REFACTOR: I'm not sure about the SampleSet line. + } + } } diff --git a/packages/squiggle-lang/src/rescript/interpreter/PTypes.res b/packages/squiggle-lang/src/rescript/interpreter/PTypes.res deleted file mode 100644 index bc9ac14a..00000000 --- a/packages/squiggle-lang/src/rescript/interpreter/PTypes.res +++ /dev/null @@ -1,138 +0,0 @@ -open ASTTypes.AST - -module Function = { - type t = (array, node) - let fromNode: node => option = node => - switch node { - | #Function(r) => Some(r) - | _ => None - } - let argumentNames = ((a, _): t) => a - let internals = ((_, b): t) => b - let run = ( - evaluationParams: ASTTypes.AST.evaluationParams, - args: array, - t: t, - ) => - if E.A.length(args) == E.A.length(argumentNames(t)) { - let newEnvironment = - Belt.Array.zip( - argumentNames(t), - args, - ) |> ASTTypes.AST.Environment.fromArray - let newEvaluationParams: ASTTypes.AST.evaluationParams = { - samplingInputs: evaluationParams.samplingInputs, - environment: ASTTypes.AST.Environment.mergeKeepSecond( - evaluationParams.environment, - newEnvironment, - ), - evaluateNode: evaluationParams.evaluateNode, - } - evaluationParams.evaluateNode(newEvaluationParams, internals(t)) - } else { - Error("Wrong number of variables") - } -} - -module Primative = { - type t = [ - | #SymbolicDist(SymbolicDistTypes.symbolicDist) - | #RenderedDist(PointSetTypes.pointSetDist) - | #Function(array, node) - ] - - let isPrimative: node => bool = x => - switch x { - | #SymbolicDist(_) - | #RenderedDist(_) - | #Function(_) => true - | _ => false - } - - let fromNode: node => option = x => - switch x { - | #SymbolicDist(_) as n - | #RenderedDist(_) as n - | #Function(_) as n => - Some(n) - | _ => None - } -} - -module SamplingDistribution = { - type t = [ - | #SymbolicDist(SymbolicDistTypes.symbolicDist) - | #RenderedDist(PointSetTypes.pointSetDist) - ] - - let isSamplingDistribution: node => bool = x => - switch x { - | #SymbolicDist(_) => true - | #RenderedDist(_) => true - | _ => false - } - - let fromNode: node => result = x => - switch x { - | #SymbolicDist(n) => Ok(#SymbolicDist(n)) - | #RenderedDist(n) => Ok(#RenderedDist(n)) - | _ => Error("Not valid type") - } - - let renderIfIsNotSamplingDistribution = (params, t): result => - !isSamplingDistribution(t) - ? switch Node.render(params, t) { - | Ok(r) => Ok(r) - | Error(e) => Error(e) - } - : Ok(t) - - let map = (~renderedDistFn, ~symbolicDistFn, node: node) => - node |> ( - x => - switch x { - | #RenderedDist(r) => Some(renderedDistFn(r)) - | #SymbolicDist(s) => Some(symbolicDistFn(s)) - | _ => None - } - ) - - let sampleN = n => - map(~renderedDistFn=PointSetDist.sampleNRendered(n), ~symbolicDistFn=SymbolicDist.T.sampleN(n)) - - let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => - switch (sampleN(n, t1), sampleN(n, t2)) { - | (Some(a), Some(b)) => - Some( - Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)), - ) - | _ => None - } - - let combineShapesUsingSampling = ( - evaluationParams: evaluationParams, - algebraicOp, - t1: node, - t2: node, - ) => { - let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1) - let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2) - E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => { - let samples = getCombinationSamples( - evaluationParams.samplingInputs.sampleCount, - algebraicOp, - a, - b, - ) - - // todo: This bottom part should probably be somewhere else. - // todo: REFACTOR: I'm not sure about the SampleSet line. - let pointSetDist = - samples - |> E.O.fmap(r => SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())) - |> E.O.bind(_, r => r.pointSetDist) - |> E.O.toResult("No response") - pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r))) - }) - } -} diff --git a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/HardcodedFunctions.res b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/HardcodedFunctions.res index 5bf93c4d..2ef80d5d 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/HardcodedFunctions.res +++ b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/HardcodedFunctions.res @@ -111,7 +111,7 @@ let verticalScaling = (scaleOp, rs, scaleBy) => { } module Multimodal = { - let getByNameResult = ASTTypes.AST.Hash.getByNameResult + let getByNameResult = Hash.getByNameResult let _paramsToDistsAndWeights = (r: array) => switch r { diff --git a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res index 95321dee..16f6e011 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res +++ b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res @@ -73,7 +73,7 @@ module TypedValue = { | _ => Error("Type Error: Expected float.") } | (#SamplingDistribution, _) => - PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( + ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution( evaluationParams, node, ) |> E.R.bind(_, fromNode) @@ -89,7 +89,7 @@ module TypedValue = { named |> E.A.fmap(((name, intendedType)) => ( name, intendedType, - ASTTypes.AST.Hash.getByName(r, name), + Hash.getByName(r, name), )) let typedHash = keyValues diff --git a/packages/squiggle-lang/src/rescript/parser/Parser.res b/packages/squiggle-lang/src/rescript/parser/Parser.res index 5f310d19..aebfa1de 100644 --- a/packages/squiggle-lang/src/rescript/parser/Parser.res +++ b/packages/squiggle-lang/src/rescript/parser/Parser.res @@ -246,7 +246,7 @@ module MathAdtToDistDst = { // let evaluatedExpression = run(expression); // `Function(_ => Ok(evaluatedExpression)); // } - let rec topLevel = (r): result => + let rec topLevel = (r): result => switch r { | FunctionAssignment({name, args, expression}) => switch nodeParser(expression) { @@ -267,7 +267,7 @@ module MathAdtToDistDst = { blocks |> E.A.fmap(b => topLevel(b)) |> E.A.R.firstErrorOrOpen |> E.R.fmap(E.A.concatMany) } - let run = (r): result => + let run = (r): result => r |> MathAdtCleaner.run |> topLevel } diff --git a/packages/squiggle-lang/src/rescript/sampleSet/SampleSet.res b/packages/squiggle-lang/src/rescript/sampleSet/SampleSet.res index 03227ed2..6b20d946 100644 --- a/packages/squiggle-lang/src/rescript/sampleSet/SampleSet.res +++ b/packages/squiggle-lang/src/rescript/sampleSet/SampleSet.res @@ -80,7 +80,7 @@ module Internals = { let toPointSetDist = ( ~samples: Internals.T.t, - ~samplingInputs: ASTTypes.AST.samplingInputs, + ~samplingInputs: SamplingInputs.samplingInputs, (), ) => { Array.fast_sort(compare, samples) diff --git a/packages/squiggle-lang/src/rescript/utility/Hash.res b/packages/squiggle-lang/src/rescript/utility/Hash.res new file mode 100644 index 00000000..14e78a3f --- /dev/null +++ b/packages/squiggle-lang/src/rescript/utility/Hash.res @@ -0,0 +1,8 @@ +type t<'a> = array<(string, 'a)> +let getByName = (t: t<'a>, name) => E.A.getBy(t, ((n, _)) => n == name) |> E.O.fmap(((_, r)) => r) + +let getByNameResult = (t: t<'a>, name) => + getByName(t, name) |> E.O.toResult(name ++ " expected and not found") + +let getByNames = (hash: t<'a>, names: array) => + names |> E.A.fmap(name => (name, getByName(hash, name))) diff --git a/packages/squiggle-lang/src/rescript/utility/SamplingInputs.res b/packages/squiggle-lang/src/rescript/utility/SamplingInputs.res new file mode 100644 index 00000000..20ed3cd9 --- /dev/null +++ b/packages/squiggle-lang/src/rescript/utility/SamplingInputs.res @@ -0,0 +1,21 @@ +type samplingInputs = { + sampleCount: int, + outputXYPoints: int, + kernelWidth: option, + pointSetDistLength: int, +} + +module SamplingInputs = { + type t = { + sampleCount: option, + outputXYPoints: option, + kernelWidth: option, + pointSetDistLength: option, + } + let withDefaults = (t: t): samplingInputs => { + sampleCount: t.sampleCount |> E.O.default(10000), + outputXYPoints: t.outputXYPoints |> E.O.default(10000), + kernelWidth: t.kernelWidth, + pointSetDistLength: t.pointSetDistLength |> E.O.default(10000), + } +}