From 814a5f2c58e3fa26295952194d163adc7790a7e2 Mon Sep 17 00:00:00 2001 From: Quinn Dougherty Date: Mon, 9 May 2022 15:19:56 -0400 Subject: [PATCH] `mx` polymorphism Value: [1e-3 to 2e-2] --- .../ReducerInterface_GenericDistribution.res | 45 ++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 1175767d..fc1c7f28 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -120,24 +120,39 @@ module Helpers = { } let mixture = (args: array): DistributionOperation.outputType => - switch E.A.last(args) { - | Some(EvArray(b)) => { - let weights = parseNumberArray(b) - let distributions = parseDistributionArray( - Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1), - ) - switch E.R.merge(distributions, weights) { - | Ok(d, w) => mixtureWithGivenWeights(d, w) + switch args { + | [EvArray(distributions)] => + switch parseDistributionArray(distributions) { + | Ok(distrs) => mixtureWithDefaultWeights(distrs) + | Error(err) => err->ArgumentError->GenDistError + } + | [EvArray(distributions), EvArray(weights)] => + switch (parseDistributionArray(distributions), parseNumberArray(weights)) { + | (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts) + | (Error(err), Ok(_)) => err->ArgumentError->GenDistError + | (Ok(_), Error(err)) => err->ArgumentError->GenDistError + | (Error(err1), Error(err2)) => `${err1}|${err2}`->ArgumentError->GenDistError + } + | _ => + switch E.A.last(args) { + | Some(EvArray(b)) => { + let weights = parseNumberArray(b) + let distributions = parseDistributionArray( + Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1), + ) + switch E.R.merge(distributions, weights) { + | Ok(d, w) => mixtureWithGivenWeights(d, w) + | Error(err) => GenDistError(ArgumentError(err)) + } + } + | Some(EvNumber(_)) + | Some(EvDistribution(_)) => + switch parseDistributionArray(args) { + | Ok(distributions) => mixtureWithDefaultWeights(distributions) | Error(err) => GenDistError(ArgumentError(err)) } + | _ => GenDistError(ArgumentError("Last argument of mx must be array or distribution")) } - | Some(EvNumber(_)) - | Some(EvDistribution(_)) => - switch parseDistributionArray(args) { - | Ok(distributions) => mixtureWithDefaultWeights(distributions) - | Error(err) => GenDistError(ArgumentError(err)) - } - | _ => GenDistError(ArgumentError("Last argument of mx must be array or distribution")) } }