open ExpressionTypes; open ExpressionTypes.ExpressionTree; type t = node; type tResult = node => result(node, string); /* Given two random variables A and B, this returns the distribution of a new variable that is the result of the operation on A and B. For instance, normal(0, 1) + normal(1, 1) -> normal(1, 2). In general, this is implemented via convolution. */ module AlgebraicCombination = { let tryAnalyticalSimplification = (operation, t1: t, t2: t) => switch (operation, t1, t2) { | (operation, `SymbolicDist(d1), `SymbolicDist(d2)) => switch (SymbolicDist.T.tryAnalyticalSimplification(d1, d2, operation)) { | `AnalyticalSolution(symbolicDist) => Ok(`SymbolicDist(symbolicDist)) | `Error(er) => Error(er) | `NoSolution => Ok(`AlgebraicCombination((operation, t1, t2))) } | _ => Ok(`AlgebraicCombination((operation, t1, t2))) }; let combinationByRendering = (evaluationParams, algebraicOp, t1: node, t2: node) : result(node, string) => { E.R.merge( Render.ensureIsRenderedAndGetShape(evaluationParams, t1), Render.ensureIsRenderedAndGetShape(evaluationParams, t2), ) |> E.R.fmap(((a, b)) => `RenderedDist(Shape.combineAlgebraically(algebraicOp, a, b)) ); }; let nodeScore: node => int = fun | `SymbolicDist(`Float(_)) => 1 | `SymbolicDist(_) => 1000 | `RenderedDist(Discrete(m)) => m.xyShape |> XYShape.T.length | `RenderedDist(Mixed(_)) => 1000 | `RenderedDist(Continuous(_)) => 1000 | _ => 1000; let choose = (t1: node, t2: node) => { nodeScore(t1) * nodeScore(t2) > 10000 ? `Sampling : `Analytical; }; let combine = (evaluationParams, algebraicOp, t1: node, t2: node) : result(node, string) => { E.R.merge( PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( evaluationParams, t1, ), PTypes.SamplingDistribution.renderIfIsNotSamplingDistribution( evaluationParams, t2, ), ) |> E.R.bind(_, ((a, b)) => switch (choose(a, b)) { | `Sampling => PTypes.SamplingDistribution.combineShapesUsingSampling( evaluationParams, algebraicOp, a, b, ) | `Analytical => combinationByRendering(evaluationParams, algebraicOp, a, b) } ); }; let operationToLeaf = ( evaluationParams: evaluationParams, algebraicOp: ExpressionTypes.algebraicOperation, t1: t, t2: t, ) : result(node, string) => algebraicOp |> tryAnalyticalSimplification(_, t1, t2) |> E.R.bind( _, fun | `SymbolicDist(_) as t => Ok(t) | _ => combine(evaluationParams, algebraicOp, t1, t2), ); }; module PointwiseCombination = { let pointwiseAdd = (evaluationParams: evaluationParams, t1: t, t2: t) => { switch ( Render.render(evaluationParams, t1), Render.render(evaluationParams, t2), ) { | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => Ok( `RenderedDist( Shape.combinePointwise( ~integralSumCachesFn=(a, b) => Some(a +. b), ~integralCachesFn= (a, b) => Some( Continuous.combinePointwise( ~distributionType=`CDF, (+.), a, b, ), ), (+.), rs1, rs2, ), ), ) | (Error(e1), _) => Error(e1) | (_, Error(e2)) => Error(e2) | _ => Error("Pointwise combination: rendering failed.") }; }; let pointwiseCombine = (fn, evaluationParams: evaluationParams, t1: t, t2: t) => { // TODO: construct a function that we can easily sample from, to construct // a RenderedDist. Use the xMin and xMax of the rendered shapes to tell the sampling function where to look. // TODO: This should work for symbolic distributions too! switch ( Render.render(evaluationParams, t1), Render.render(evaluationParams, t2), ) { | (Ok(`RenderedDist(rs1)), Ok(`RenderedDist(rs2))) => Ok(`RenderedDist(Shape.combinePointwise(fn, rs1, rs2))) | (Error(e1), _) => Error(e1) | (_, Error(e2)) => Error(e2) | _ => Error("Pointwise combination: rendering failed.") }; }; let operationToLeaf = ( evaluationParams: evaluationParams, pointwiseOp: pointwiseOperation, t1: t, t2: t, ) => { switch (pointwiseOp) { | `Add => pointwiseAdd(evaluationParams, t1, t2) | `Multiply => pointwiseCombine(( *. ), evaluationParams, t1, t2) | `Exponentiate => pointwiseCombine(( ** ), evaluationParams, t1, t2) }; }; }; module Truncate = { let trySimplification = (leftCutoff, rightCutoff, t): simplificationResult => { switch (leftCutoff, rightCutoff, t) { | (None, None, t) => `Solution(t) | (Some(lc), Some(rc), _) when lc > rc => `Error( "Left truncation bound must be smaller than right truncation bound.", ) | (lc, rc, `SymbolicDist(`Uniform(u))) => `Solution( `SymbolicDist(`Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))), ) | _ => `NoSolution }; }; let truncateAsShape = (evaluationParams: evaluationParams, leftCutoff, rightCutoff, t) => { // TODO: use named args for xMin/xMax in renderToShape; if we're lucky we can at least get the tail // of a distribution we otherwise wouldn't get at all switch (Render.ensureIsRendered(evaluationParams, t)) { | Ok(`RenderedDist(rs)) => Ok(`RenderedDist(Shape.T.truncate(leftCutoff, rightCutoff, rs))) | Error(e) => Error(e) | _ => Error("Could not truncate distribution.") }; }; let operationToLeaf = ( evaluationParams, leftCutoff: option(float), rightCutoff: option(float), t: node, ) : result(node, string) => { t |> trySimplification(leftCutoff, rightCutoff) |> ( fun | `Solution(t) => Ok(t) | `Error(e) => Error(e) | `NoSolution => truncateAsShape(evaluationParams, leftCutoff, rightCutoff, t) ); }; }; module Normalize = { let rec operationToLeaf = (evaluationParams, t: node): result(node, string) => { Js.log2("normalize", t); switch (t) { | `RenderedDist(s) => Ok(`RenderedDist(Shape.T.normalize(s))) | `SymbolicDist(_) => Ok(t) | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) }; }; }; module FunctionCall = { let _runHardcodedFunction = (name, evaluationParams, args) => TypeSystem.Function.Ts.findByNameAndRun( HardcodedFunctions.all, name, evaluationParams, args, ); let _runLocalFunction = (name, evaluationParams: evaluationParams, args) => { Environment.getFunction(evaluationParams.environment, name) |> E.R.bind(_, ((argNames, fn)) => PTypes.Function.run(evaluationParams, args, (argNames, fn)) ); }; let _runWithEvaluatedInputs = ( evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, name, args: array(ExpressionTypes.ExpressionTree.node), ) => { _runHardcodedFunction(name, evaluationParams, args) |> E.O.default(_runLocalFunction(name, evaluationParams, args)); }; // TODO: This forces things to be floats let run = (evaluationParams, name, args) => { args |> E.A.fmap(a => evaluationParams.evaluateNode(evaluationParams, a)) |> E.A.R.firstErrorOrOpen |> E.R.bind(_, _runWithEvaluatedInputs(evaluationParams, name)); }; }; module Render = { let rec operationToLeaf = (evaluationParams: evaluationParams, t: node): result(t, string) => { switch (t) { | `Function(_) => Error("Cannot render a function") | `SymbolicDist(d) => Ok( `RenderedDist( SymbolicDist.T.toShape( evaluationParams.samplingInputs.shapeLength, d, ), ), ) | `RenderedDist(_) as t => Ok(t) // already a rendered shape, we're done here | _ => evaluateAndRetry(evaluationParams, operationToLeaf, t) }; }; }; /* This function recursively goes through the nodes of the parse tree, replacing each Operation node and its subtree with a Data node. Whenever possible, the replacement produces a new Symbolic Data node, but most often it will produce a RenderedDist. This function is used mainly to turn a parse tree into a single RenderedDist that can then be displayed to the user. */ let rec toLeaf = ( evaluationParams: ExpressionTypes.ExpressionTree.evaluationParams, node: t, ) : result(t, string) => { switch (node) { // Leaf nodes just stay leaf nodes | `SymbolicDist(_) | `Function(_) | `RenderedDist(_) => Ok(node) | `Array(args) => args |> E.A.fmap(toLeaf(evaluationParams)) |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => `Array(r)) // Operations nevaluationParamsd to be turned into leaves | `AlgebraicCombination(algebraicOp, t1, t2) => AlgebraicCombination.operationToLeaf( evaluationParams, algebraicOp, t1, t2, ) | `PointwiseCombination(pointwiseOp, t1, t2) => PointwiseCombination.operationToLeaf( evaluationParams, pointwiseOp, t1, t2, ) | `Truncate(leftCutoff, rightCutoff, t) => Truncate.operationToLeaf(evaluationParams, leftCutoff, rightCutoff, t) | `Normalize(t) => Normalize.operationToLeaf(evaluationParams, t) | `Render(t) => Render.operationToLeaf(evaluationParams, t) | `Hash(t) => t |> E.A.fmap(((name: string, node: node)) => toLeaf(evaluationParams, node) |> E.R.fmap(r => (name, r)) ) |> E.A.R.firstErrorOrOpen |> E.R.fmap(r => `Hash(r)) | `Symbol(r) => ExpressionTypes.ExpressionTree.Environment.get( evaluationParams.environment, r, ) |> E.O.toResult("Undeclared variable " ++ r) |> E.R.bind(_, toLeaf(evaluationParams)) | `FunctionCall(name, args) => FunctionCall.run(evaluationParams, name, args) |> E.R.bind(_, toLeaf(evaluationParams)) } };