diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index 1175767d..fa7be3be 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -119,26 +119,43 @@ module Helpers = { mixtureWithGivenWeights(distributions, weights) } - let mixture = (args: array): DistributionOperation.outputType => - switch E.A.last(args) { - | Some(EvArray(b)) => { - let weights = parseNumberArray(b) - let distributions = parseDistributionArray( - Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1), - ) - switch E.R.merge(distributions, weights) { - | Ok(d, w) => mixtureWithGivenWeights(d, w) - | Error(err) => GenDistError(ArgumentError(err)) + let mixture = (args: array): DistributionOperation.outputType => { + let error = (err: string): DistributionOperation.outputType => err->ArgumentError->GenDistError + switch args { + | [EvArray(distributions)] => + switch parseDistributionArray(distributions) { + | Ok(distrs) => mixtureWithDefaultWeights(distrs) + | Error(err) => error(err) + } + | [EvArray(distributions), EvArray(weights)] => + switch (parseDistributionArray(distributions), parseNumberArray(weights)) { + | (Ok(distrs), Ok(wghts)) => mixtureWithGivenWeights(distrs, wghts) + | (Error(err), Ok(_)) => error(err) + | (Ok(_), Error(err)) => error(err) + | (Error(err1), Error(err2)) => error(`${err1}|${err2}`) + } + | _ => + switch E.A.last(args) { + | Some(EvArray(b)) => { + let weights = parseNumberArray(b) + let distributions = parseDistributionArray( + Belt.Array.slice(args, ~offset=0, ~len=E.A.length(args) - 1), + ) + switch E.R.merge(distributions, weights) { + | Ok(d, w) => mixtureWithGivenWeights(d, w) + | Error(err) => error(err) + } } + | Some(EvNumber(_)) + | Some(EvDistribution(_)) => + switch parseDistributionArray(args) { + | Ok(distributions) => mixtureWithDefaultWeights(distributions) + | Error(err) => error(err) + } + | _ => error("Last argument of mx must be array or distribution") } - | Some(EvNumber(_)) - | Some(EvDistribution(_)) => - switch parseDistributionArray(args) { - | Ok(distributions) => mixtureWithDefaultWeights(distributions) - | Error(err) => GenDistError(ArgumentError(err)) - } - | _ => GenDistError(ArgumentError("Last argument of mx must be array or distribution")) } + } } module SymbolicConstructors = {