diff --git a/packages/squiggle-lang/__tests__/Hardcoded__Test.res b/packages/squiggle-lang/__tests__/Hardcoded__Test.res index bad6a5e3..760c0010 100644 --- a/packages/squiggle-lang/__tests__/Hardcoded__Test.res +++ b/packages/squiggle-lang/__tests__/Hardcoded__Test.res @@ -10,7 +10,7 @@ let makeTest = (~only=false, str, item1, item2) => expect(item1) |> toEqual(item2) ); -let evalParams: ASTTypes.AST.evaluationParams = { +let evalParams: ASTTypes.evaluationParams = { samplingInputs: { sampleCount: 1000, outputXYPoints: 10000, diff --git a/packages/squiggle-lang/src/rescript/ProgramEvaluator.res b/packages/squiggle-lang/src/rescript/ProgramEvaluator.res index 134daec3..f69d2c43 100644 --- a/packages/squiggle-lang/src/rescript/ProgramEvaluator.res +++ b/packages/squiggle-lang/src/rescript/ProgramEvaluator.res @@ -14,7 +14,7 @@ module Inputs = { type inputs = { squiggleString: string, samplingInputs: SamplingInputs.t, - environment: ASTTypes.AST.environment, + environment: ASTTypes.environment, } let empty: SamplingInputs.t = { @@ -27,7 +27,7 @@ module Inputs = { let make = ( ~samplingInputs=empty, ~squiggleString, - ~environment=ASTTypes.AST.Environment.empty, + ~environment=ASTTypes.Environment.empty, (), ): inputs => { samplingInputs: samplingInputs, @@ -40,8 +40,8 @@ type \"export" = [ | #DistPlus(DistPlus.t) | #Float(float) | #Function( - (array, ASTTypes.AST.node), - ASTTypes.AST.environment, + (array, ASTTypes.node), + ASTTypes.environment, ) ] @@ -53,13 +53,13 @@ module Internals = { ): Inputs.inputs => { samplingInputs: samplingInputs, squiggleString: squiggleString, - environment: ASTTypes.AST.Environment.update(environment, str, _ => Some( + environment: ASTTypes.Environment.update(environment, str, _ => Some( node, )), } type outputs = { - graph: ASTTypes.AST.node, + graph: ASTTypes.node, pointSetDist: PointSetTypes.pointSetDist, } let makeOutputs = (graph, pointSetDist): outputs => {graph: graph, pointSetDist: pointSetDist} @@ -74,7 +74,7 @@ module Internals = { let runNode = (inputs, node) => AST.toLeaf(makeInputs(inputs), inputs.environment, node) - let runProgram = (inputs: Inputs.inputs, p: ASTTypes.AST.program) => { + let runProgram = (inputs: Inputs.inputs, p: ASTTypes.program) => { let ins = ref(inputs) p |> E.A.fmap(x => @@ -97,8 +97,8 @@ module Internals = { DistPlus.make(~pointSetDist, ~squiggleString=Some(inputs.squiggleString), ()) } -let renderIfNeeded = (inputs: Inputs.inputs, node: ASTTypes.AST.node): result< - ASTTypes.AST.node, +let renderIfNeeded = (inputs: Inputs.inputs, node: ASTTypes.node): result< + ASTTypes.node, string, > => node |> ( @@ -121,11 +121,11 @@ let renderIfNeeded = (inputs: Inputs.inputs, node: ASTTypes.AST.node): result< } ) -// TODO: Consider using ASTTypes.AST.getFloat or similar in this function +// TODO: Consider using ASTTypes.getFloat or similar in this function let coersionToExportedTypes = ( inputs, - env: ASTTypes.AST.environment, - node: ASTTypes.AST.node, + env: ASTTypes.environment, + node: ASTTypes.node, ): result<\"export", string> => node |> renderIfNeeded(inputs) @@ -160,7 +160,7 @@ let evaluateProgram = (inputs: Inputs.inputs) => let evaluateFunction = ( inputs: Inputs.inputs, - fn: (array, ASTTypes.AST.node), + fn: (array, ASTTypes.node), fnInputs, ) => { let output = AST.runFunction( diff --git a/packages/squiggle-lang/src/rescript/interpreter/AST.res b/packages/squiggle-lang/src/rescript/interpreter/AST.res index 4b105f1d..2dca6ffc 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/AST.res +++ b/packages/squiggle-lang/src/rescript/interpreter/AST.res @@ -1,6 +1,6 @@ -open ASTTypes.AST +open ASTTypes -let toString = ASTTypes.AST.Node.toString +let toString = ASTTypes.Node.toString let envs = (samplingInputs, environment) => { samplingInputs: samplingInputs, @@ -18,7 +18,7 @@ let toPointSetDist = (samplingInputs, environment, node: node) => | Error(e) => Error(e) } -let runFunction = (samplingInputs, environment, inputs, fn: ASTTypes.AST.Function.t) => { +let runFunction = (samplingInputs, environment, inputs, fn: ASTTypes.Function.t) => { let params = envs(samplingInputs, environment) - ASTTypes.AST.Function.run(params, inputs, fn) -} + ASTTypes.Function.run(params, inputs, fn) +} \ No newline at end of file diff --git a/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res b/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res index 481bd382..5cadff89 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res +++ b/packages/squiggle-lang/src/rescript/interpreter/ASTEvaluator.res @@ -1,4 +1,4 @@ -open ASTTypes.AST +open ASTTypes type t = node type tResult = node => result @@ -43,12 +43,12 @@ module AlgebraicCombination = { let combine = (evaluationParams, algebraicOp, t1: node, t2: node): result => E.R.merge( - ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1), - ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2), + ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t1), + ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution(evaluationParams, t2), ) |> E.R.bind(_, ((a, b)) => switch choose(a, b) { | #Sampling => - ASTTypes.AST.SamplingDistribution.combineShapesUsingSampling( + ASTTypes.SamplingDistribution.combineShapesUsingSampling( evaluationParams, algebraicOp, a, @@ -123,7 +123,7 @@ module PointwiseCombination = { module Truncate = { type simplificationResult = [ - | #Solution(ASTTypes.AST.node) + | #Solution(ASTTypes.node) | #Error(string) | #NoSolution ] @@ -171,7 +171,7 @@ module Normalize = { switch t { | #RenderedDist(s) => Ok(#RenderedDist(PointSetDist.T.normalize(s))) | #SymbolicDist(_) => Ok(t) - | _ => ASTTypes.AST.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t) + | _ => ASTTypes.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t) } } @@ -181,13 +181,13 @@ module FunctionCall = { let _runLocalFunction = (name, evaluationParams: evaluationParams, args) => Environment.getFunction(evaluationParams.environment, name) |> E.R.bind(_, ((argNames, fn)) => - ASTTypes.AST.Function.run(evaluationParams, args, (argNames, fn)) + ASTTypes.Function.run(evaluationParams, args, (argNames, fn)) ) let _runWithEvaluatedInputs = ( - evaluationParams: ASTTypes.AST.evaluationParams, + evaluationParams: ASTTypes.evaluationParams, name, - args: array, + args: array, ) => _runHardcodedFunction(name, evaluationParams, args) |> E.O.default( _runLocalFunction(name, evaluationParams, args), @@ -212,7 +212,7 @@ module Render = { ), ) | #RenderedDist(_) as t => Ok(t) // already a rendered pointSetDist, we're done here - | _ => ASTTypes.AST.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t) + | _ => ASTTypes.Node.evaluateAndRetry(evaluationParams, operationToLeaf, t) } } @@ -222,7 +222,7 @@ module Render = { but most often it will produce a RenderedDist. This function is used mainly to turn a parse tree into a single RenderedDist that can then be displayed to the user. */ -let rec toLeaf = (evaluationParams: ASTTypes.AST.evaluationParams, node: t): result => +let rec toLeaf = (evaluationParams: ASTTypes.evaluationParams, node: t): result => switch node { // Leaf nodes just stay leaf nodes | #SymbolicDist(_) @@ -248,7 +248,7 @@ let rec toLeaf = (evaluationParams: ASTTypes.AST.evaluationParams, node: t): res |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => #Hash(r)) | #Symbol(r) => - ASTTypes.AST.Environment.get(evaluationParams.environment, r) + ASTTypes.Environment.get(evaluationParams.environment, r) |> E.O.toResult("Undeclared variable " ++ r) |> E.R.bind(_, toLeaf(evaluationParams)) | #FunctionCall(name, args) => diff --git a/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res b/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res index 8526e9ec..95498027 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res +++ b/packages/squiggle-lang/src/rescript/interpreter/ASTTypes.res @@ -1,272 +1,257 @@ -module AST = { - type rec hash = array<(string, node)> - and node = [ - | #SymbolicDist(SymbolicDistTypes.symbolicDist) - | #RenderedDist(PointSetTypes.pointSetDist) - | #Symbol(string) - | #Hash(hash) - | #Array(array) - | #Function(array, node) - | #AlgebraicCombination(Operation.algebraicOperation, node, node) - | #PointwiseCombination(Operation.pointwiseOperation, node, node) - | #Normalize(node) - | #Render(node) - | #Truncate(option, option, node) - | #FunctionCall(string, array) - ] +type rec hash = array<(string, node)> +and node = [ + | #SymbolicDist(SymbolicDistTypes.symbolicDist) + | #RenderedDist(PointSetTypes.pointSetDist) + | #Symbol(string) + | #Hash(hash) + | #Array(array) + | #Function(array, node) + | #AlgebraicCombination(Operation.algebraicOperation, node, node) + | #PointwiseCombination(Operation.pointwiseOperation, node, node) + | #Normalize(node) + | #Render(node) + | #Truncate(option, option, node) + | #FunctionCall(string, array) +] - type statement = [ - | #Assignment(string, node) - | #Expression(node) - ] - type program = array +type statement = [ + | #Assignment(string, node) + | #Expression(node) +] +type program = array - type environment = Belt.Map.String.t +type environment = Belt.Map.String.t - type rec evaluationParams = { - samplingInputs: SamplingInputs.samplingInputs, - environment: environment, - evaluateNode: (evaluationParams, node) => Belt.Result.t, - } +type rec evaluationParams = { + samplingInputs: SamplingInputs.samplingInputs, + environment: environment, + evaluateNode: (evaluationParams, node) => Belt.Result.t, +} - module Environment = { - type t = environment - module MS = Belt.Map.String - let fromArray = MS.fromArray - let empty: t = []->fromArray - let mergeKeepSecond = (a: t, b: t) => - MS.merge(a, b, (_, a, b) => - switch (a, b) { - | (_, Some(b)) => Some(b) - | (Some(a), _) => Some(a) +module Environment = { + type t = environment + module MS = Belt.Map.String + let fromArray = MS.fromArray + let empty: t = []->fromArray + let mergeKeepSecond = (a: t, b: t) => + MS.merge(a, b, (_, a, b) => + switch (a, b) { + | (_, Some(b)) => Some(b) + | (Some(a), _) => Some(a) + | _ => None + } + ) + let update = (t, str, fn) => MS.update(t, str, fn) + let get = (t: t, str) => MS.get(t, str) + let getFunction = (t: t, str) => + switch get(t, str) { + | Some(#Function(argNames, fn)) => Ok((argNames, fn)) + | _ => Error("Function " ++ (str ++ " not found")) + } +} + +module Node = { + let getFloat = (node: node) => + node |> ( + x => + switch x { + | #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x) + | #SymbolicDist(#Float(x)) => Some(x) | _ => None } - ) - let update = (t, str, fn) => MS.update(t, str, fn) - let get = (t: t, str) => MS.get(t, str) - let getFunction = (t: t, str) => - switch get(t, str) { - | Some(#Function(argNames, fn)) => Ok((argNames, fn)) - | _ => Error("Function " ++ (str ++ " not found")) - } - } + ) - module Node = { - let getFloat = (node: node) => - node |> ( - x => - switch x { - | #RenderedDist(Discrete({xyShape: {xs: [x], ys: [1.0]}})) => Some(x) - | #SymbolicDist(#Float(x)) => Some(x) - | _ => None - } - ) + let evaluate = (evaluationParams: evaluationParams) => + evaluationParams.evaluateNode(evaluationParams) - let evaluate = (evaluationParams: evaluationParams) => - evaluationParams.evaluateNode(evaluationParams) + let evaluateAndRetry = (evaluationParams, fn, node) => + node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams)) - let evaluateAndRetry = (evaluationParams, fn, node) => - node |> evaluationParams.evaluateNode(evaluationParams) |> E.R.bind(_, fn(evaluationParams)) + let rec toString: node => string = x => + switch x { + | #SymbolicDist(d) => SymbolicDist.T.toString(d) + | #RenderedDist(_) => "[renderedShape]" + | #AlgebraicCombination(op, t1, t2) => + Operation.Algebraic.format(op, toString(t1), toString(t2)) + | #PointwiseCombination(op, t1, t2) => + Operation.Pointwise.format(op, toString(t1), toString(t2)) + | #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")") + | #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t)) + | #Render(t) => toString(t) + | #Symbol(t) => "Symbol: " ++ t + | #FunctionCall(name, args) => + "[Function call: (" ++ + (name ++ + ((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]")) + | #Function(args, internal) => + "[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]")) + | #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]") + | #Hash(h) => + "{" ++ + ((h + |> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value))) + |> Js.String.concatMany(_, ",")) ++ + "}") + } - let rec toString: node => string = x => - switch x { - | #SymbolicDist(d) => SymbolicDist.T.toString(d) - | #RenderedDist(_) => "[renderedShape]" - | #AlgebraicCombination(op, t1, t2) => - Operation.Algebraic.format(op, toString(t1), toString(t2)) - | #PointwiseCombination(op, t1, t2) => - Operation.Pointwise.format(op, toString(t1), toString(t2)) - | #Normalize(t) => "normalize(k" ++ (toString(t) ++ ")") - | #Truncate(lc, rc, t) => Operation.Truncate.toString(lc, rc, toString(t)) - | #Render(t) => toString(t) - | #Symbol(t) => "Symbol: " ++ t - | #FunctionCall(name, args) => - "[Function call: (" ++ - (name ++ - ((args |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ ")]")) - | #Function(args, internal) => - "[Function: (" ++ ((args |> Js.String.concatMany(_, ",")) ++ (toString(internal) ++ ")]")) - | #Array(a) => "[" ++ ((a |> E.A.fmap(toString) |> Js.String.concatMany(_, ",")) ++ "]") - | #Hash(h) => - "{" ++ - ((h - |> E.A.fmap(((name, value)) => name ++ (":" ++ toString(value))) - |> Js.String.concatMany(_, ",")) ++ - "}") - } + let render = (evaluationParams: evaluationParams, r) => #Render(r) |> evaluate(evaluationParams) - let render = (evaluationParams: evaluationParams, r) => #Render(r) |> evaluate(evaluationParams) - - let ensureIsRendered = (params, t) => - switch t { - | #RenderedDist(_) => Ok(t) - | _ => - switch render(params, t) { - | Ok(#RenderedDist(r)) => Ok(#RenderedDist(r)) - | Ok(_) => Error("Did not render as requested") - | Error(e) => Error(e) - } - } - - let ensureIsRenderedAndGetShape = (params, t) => - switch ensureIsRendered(params, t) { - | Ok(#RenderedDist(r)) => Ok(r) + let ensureIsRendered = (params, t) => + switch t { + | #RenderedDist(_) => Ok(t) + | _ => + switch render(params, t) { + | Ok(#RenderedDist(r)) => Ok(#RenderedDist(r)) | Ok(_) => Error("Did not render as requested") | Error(e) => Error(e) } - - let toPointSetDist = (item: node) => - switch item { - | #RenderedDist(r) => Some(r) - | _ => None - } - - let _toFloat = (t: PointSetTypes.pointSetDist) => - switch t { - | Discrete({xyShape: {xs: [x], ys: [1.0]}}) => Some(#SymbolicDist(#Float(x))) - | _ => None - } - - let toFloat = (item: node): result => - item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape") - } - - module Function = { - type t = (array, node) - let fromNode: node => option = node => - switch node { - | #Function(r) => Some(r) - | _ => None - } - let argumentNames = ((a, _): t) => a - let internals = ((_, b): t) => b - let run = (evaluationParams: evaluationParams, args: array, t: t) => - if E.A.length(args) == E.A.length(argumentNames(t)) { - let newEnvironment = Belt.Array.zip(argumentNames(t), args) |> Environment.fromArray - let newEvaluationParams: evaluationParams = { - samplingInputs: evaluationParams.samplingInputs, - environment: Environment.mergeKeepSecond( - evaluationParams.environment, - newEnvironment, - ), - evaluateNode: evaluationParams.evaluateNode, - } - evaluationParams.evaluateNode(newEvaluationParams, internals(t)) - } else { - Error("Wrong number of variables") - } - } - - module Primative = { - type t = [ - | #SymbolicDist(SymbolicDistTypes.symbolicDist) - | #RenderedDist(PointSetTypes.pointSetDist) - | #Function(array, node) - ] - - let isPrimative: node => bool = x => - switch x { - | #SymbolicDist(_) - | #RenderedDist(_) - | #Function(_) => true - | _ => false - } - - let fromNode: node => option = x => - switch x { - | #SymbolicDist(_) as n - | #RenderedDist(_) as n - | #Function(_) as n => - Some(n) - | _ => None - } - } - - module SamplingDistribution = { - type t = [ - | #SymbolicDist(SymbolicDistTypes.symbolicDist) - | #RenderedDist(PointSetTypes.pointSetDist) - ] - - let isSamplingDistribution: node => bool = x => - switch x { - | #SymbolicDist(_) => true - | #RenderedDist(_) => true - | _ => false - } - - let fromNode: node => result = x => - switch x { - | #SymbolicDist(n) => Ok(#SymbolicDist(n)) - | #RenderedDist(n) => Ok(#RenderedDist(n)) - | _ => Error("Not valid type") - } - - let renderIfIsNotSamplingDistribution = (params, t): result => - !isSamplingDistribution(t) - ? switch Node.render(params, t) { - | Ok(r) => Ok(r) - | Error(e) => Error(e) - } - : Ok(t) - - let map = (~renderedDistFn, ~symbolicDistFn, node: node) => - node |> ( - x => - switch x { - | #RenderedDist(r) => Some(renderedDistFn(r)) - | #SymbolicDist(s) => Some(symbolicDistFn(s)) - | _ => None - } - ) - - let sampleN = n => - map( - ~renderedDistFn=PointSetDist.sampleNRendered(n), - ~symbolicDistFn=SymbolicDist.T.sampleN(n), - ) - - let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => - switch (sampleN(n, t1), sampleN(n, t2)) { - | (Some(a), Some(b)) => - Some( - Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)), - ) - | _ => None - } - - let combineShapesUsingSampling = ( - evaluationParams: evaluationParams, - algebraicOp, - t1: node, - t2: node, - ) => { - let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1) - let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2) - E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => { - let samples = getCombinationSamples( - evaluationParams.samplingInputs.sampleCount, - algebraicOp, - a, - b, - ) - - let pointSetDist = - samples - |> E.O.fmap(r => - SampleSet.toPointSetDist( - ~samplingInputs=evaluationParams.samplingInputs, - ~samples=r, - (), - ) - ) - |> E.O.bind(_, r => r.pointSetDist) - |> E.O.toResult("No response") - pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r))) - }) - - // todo: This bottom part should probably be somewhere else. - // todo: REFACTOR: I'm not sure about the SampleSet line. } - } + + let ensureIsRenderedAndGetShape = (params, t) => + switch ensureIsRendered(params, t) { + | Ok(#RenderedDist(r)) => Ok(r) + | Ok(_) => Error("Did not render as requested") + | Error(e) => Error(e) + } + + let toPointSetDist = (item: node) => + switch item { + | #RenderedDist(r) => Some(r) + | _ => None + } + + let _toFloat = (t: PointSetTypes.pointSetDist) => + switch t { + | Discrete({xyShape: {xs: [x], ys: [1.0]}}) => Some(#SymbolicDist(#Float(x))) + | _ => None + } + + let toFloat = (item: node): result => + item |> toPointSetDist |> E.O.bind(_, _toFloat) |> E.O.toResult("Not valid shape") } + +module Function = { + type t = (array, node) + let fromNode: node => option = node => + switch node { + | #Function(r) => Some(r) + | _ => None + } + let argumentNames = ((a, _): t) => a + let internals = ((_, b): t) => b + let run = (evaluationParams: evaluationParams, args: array, t: t) => + if E.A.length(args) == E.A.length(argumentNames(t)) { + let newEnvironment = Belt.Array.zip(argumentNames(t), args) |> Environment.fromArray + let newEvaluationParams: evaluationParams = { + samplingInputs: evaluationParams.samplingInputs, + environment: Environment.mergeKeepSecond(evaluationParams.environment, newEnvironment), + evaluateNode: evaluationParams.evaluateNode, + } + evaluationParams.evaluateNode(newEvaluationParams, internals(t)) + } else { + Error("Wrong number of variables") + } +} + +module Primative = { + type t = [ + | #SymbolicDist(SymbolicDistTypes.symbolicDist) + | #RenderedDist(PointSetTypes.pointSetDist) + | #Function(array, node) + ] + + let isPrimative: node => bool = x => + switch x { + | #SymbolicDist(_) + | #RenderedDist(_) + | #Function(_) => true + | _ => false + } + + let fromNode: node => option = x => + switch x { + | #SymbolicDist(_) as n + | #RenderedDist(_) as n + | #Function(_) as n => + Some(n) + | _ => None + } +} + +module SamplingDistribution = { + type t = [ + | #SymbolicDist(SymbolicDistTypes.symbolicDist) + | #RenderedDist(PointSetTypes.pointSetDist) + ] + + let isSamplingDistribution: node => bool = x => + switch x { + | #SymbolicDist(_) => true + | #RenderedDist(_) => true + | _ => false + } + + let fromNode: node => result = x => + switch x { + | #SymbolicDist(n) => Ok(#SymbolicDist(n)) + | #RenderedDist(n) => Ok(#RenderedDist(n)) + | _ => Error("Not valid type") + } + + let renderIfIsNotSamplingDistribution = (params, t): result => + !isSamplingDistribution(t) + ? switch Node.render(params, t) { + | Ok(r) => Ok(r) + | Error(e) => Error(e) + } + : Ok(t) + + let map = (~renderedDistFn, ~symbolicDistFn, node: node) => + node |> ( + x => + switch x { + | #RenderedDist(r) => Some(renderedDistFn(r)) + | #SymbolicDist(s) => Some(symbolicDistFn(s)) + | _ => None + } + ) + + let sampleN = n => + map(~renderedDistFn=PointSetDist.sampleNRendered(n), ~symbolicDistFn=SymbolicDist.T.sampleN(n)) + + let getCombinationSamples = (n, algebraicOp, t1: node, t2: node) => + switch (sampleN(n, t1), sampleN(n, t2)) { + | (Some(a), Some(b)) => + Some( + Belt.Array.zip(a, b) |> E.A.fmap(((a, b)) => Operation.Algebraic.toFn(algebraicOp, a, b)), + ) + | _ => None + } + + let combineShapesUsingSampling = ( + evaluationParams: evaluationParams, + algebraicOp, + t1: node, + t2: node, + ) => { + let i1 = renderIfIsNotSamplingDistribution(evaluationParams, t1) + let i2 = renderIfIsNotSamplingDistribution(evaluationParams, t2) + E.R.merge(i1, i2) |> E.R.bind(_, ((a, b)) => { + let samples = getCombinationSamples( + evaluationParams.samplingInputs.sampleCount, + algebraicOp, + a, + b, + ) + + let pointSetDist = + samples + |> E.O.fmap(r => + SampleSet.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()) + ) + |> E.O.bind(_, r => r.pointSetDist) + |> E.O.toResult("No response") + pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r))) + }) + } +} \ No newline at end of file diff --git a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res index 16f6e011..51332106 100644 --- a/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res +++ b/packages/squiggle-lang/src/rescript/interpreter/typeSystem/TypeSystem.res @@ -1,5 +1,5 @@ -type node = ASTTypes.AST.node -let getFloat = ASTTypes.AST.Node.getFloat +type node = ASTTypes.node +let getFloat = ASTTypes.Node.getFloat type samplingDist = [ | #SymbolicDist(SymbolicDistTypes.symbolicDist) @@ -61,7 +61,7 @@ module TypedValue = { |> E.A.fmap(((name, t)) => fromNode(t) |> E.R.fmap(r => (name, r))) |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => #Hash(r)) - | e => Error("Wrong type: " ++ ASTTypes.AST.Node.toString(e)) + | e => Error("Wrong type: " ++ ASTTypes.Node.toString(e)) } // todo: Arrays and hashes @@ -73,12 +73,12 @@ module TypedValue = { | _ => Error("Type Error: Expected float.") } | (#SamplingDistribution, _) => - ASTTypes.AST.SamplingDistribution.renderIfIsNotSamplingDistribution( + ASTTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( evaluationParams, node, ) |> E.R.bind(_, fromNode) | (#RenderedDistribution, _) => - ASTTypes.AST.Node.render(evaluationParams, node) |> E.R.bind(_, fromNode) + ASTTypes.Node.render(evaluationParams, node) |> E.R.bind(_, fromNode) | (#Array(_type), #Array(b)) => b |> E.A.fmap(fromNodeWithTypeCoercion(evaluationParams, _type)) @@ -172,7 +172,7 @@ module Function = { |> E.A.R.firstErrorOrOpen let inputsToTypedValues = ( - evaluationParams: ASTTypes.AST.evaluationParams, + evaluationParams: ASTTypes.evaluationParams, inputNodes: inputNodes, t: t, ) => @@ -181,7 +181,7 @@ module Function = { ) let run = ( - evaluationParams: ASTTypes.AST.evaluationParams, + evaluationParams: ASTTypes.evaluationParams, inputNodes: inputNodes, t: t, ) => diff --git a/packages/squiggle-lang/src/rescript/parser/Parser.res b/packages/squiggle-lang/src/rescript/parser/Parser.res index aebfa1de..28e1d42a 100644 --- a/packages/squiggle-lang/src/rescript/parser/Parser.res +++ b/packages/squiggle-lang/src/rescript/parser/Parser.res @@ -122,7 +122,7 @@ module MathAdtToDistDst = { | _ => Error("Lognormal distribution needs either mean and stdev or mu and sigma") } | _ => - parseArgs() |> E.R.fmap((args: array) => + parseArgs() |> E.R.fmap((args: array) => #FunctionCall("lognormal", args) ) } @@ -130,8 +130,8 @@ module MathAdtToDistDst = { // Error("Dotwise exponentiation needs two operands") let operationParser = ( name: string, - args: result, string>, - ): result => { + args: result, string>, + ): result => { let toOkAlgebraic = r => Ok(#AlgebraicCombination(r)) let toOkPointwise = r => Ok(#PointwiseCombination(r)) let toOkTruncate = r => Ok(#Truncate(r)) @@ -170,12 +170,12 @@ module MathAdtToDistDst = { let functionParser = ( nodeParser: MathJsonToMathJsAdt.arg => Belt.Result.t< - ASTTypes.AST.node, + ASTTypes.node, string, >, name: string, args: array, - ): result => { + ): result => { let parseArray = ags => ags |> E.A.fmap(nodeParser) |> E.A.R.firstErrorOrOpen let parseArgs = () => parseArray(args) switch name { @@ -212,27 +212,27 @@ module MathAdtToDistDst = { | (Some(Error(r)), _) => Error(r) | (_, Error(r)) => Error(r) | (None, Ok(dists)) => - let hash: ASTTypes.AST.node = #FunctionCall( + let hash: ASTTypes.node = #FunctionCall( "multimodal", [#Hash([("dists", #Array(dists)), ("weights", #Array([]))])], ) Ok(hash) | (Some(Ok(weights)), Ok(dists)) => - let hash: ASTTypes.AST.node = #FunctionCall( + let hash: ASTTypes.node = #FunctionCall( "multimodal", [#Hash([("dists", #Array(dists)), ("weights", #Array(weights))])], ) Ok(hash) } | name => - parseArgs() |> E.R.fmap((args: array) => + parseArgs() |> E.R.fmap((args: array) => #FunctionCall(name, args) ) } } let rec nodeParser: MathJsonToMathJsAdt.arg => result< - ASTTypes.AST.node, + ASTTypes.node, string, > = x => switch x { @@ -246,7 +246,7 @@ module MathAdtToDistDst = { // let evaluatedExpression = run(expression); // `Function(_ => Ok(evaluatedExpression)); // } - let rec topLevel = (r): result => + let rec topLevel = (r): result => switch r { | FunctionAssignment({name, args, expression}) => switch nodeParser(expression) { @@ -267,7 +267,7 @@ module MathAdtToDistDst = { blocks |> E.A.fmap(b => topLevel(b)) |> E.A.R.firstErrorOrOpen |> E.R.fmap(E.A.concatMany) } - let run = (r): result => + let run = (r): result => r |> MathAdtCleaner.run |> topLevel }