diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res index 668fcd07..25eb70f2 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.res @@ -182,6 +182,11 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { ) ->E.R2.fmap(r => Dist(r)) ->OutputLocal.fromResult + | ToDist(Scale(#Multiply, f)) => + dist + ->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~algebraicCombination=#Multiply, ~f) + ->E.R2.fmap(r => Dist(r)) + ->OutputLocal.fromResult | ToDist(Scale(#Logarithm, f)) => dist ->GenericDist.pointwiseCombinationFloat(~toPointSetFn, ~algebraicCombination=#Logarithm, ~f) @@ -298,6 +303,7 @@ module Constructors = { let algebraicLogarithm = (~env, dist1, dist2) => C.algebraicLogarithm(dist1, dist2)->run(~env)->toDistR let algebraicPower = (~env, dist1, dist2) => C.algebraicPower(dist1, dist2)->run(~env)->toDistR + let scaleMultiply = (~env, dist, n) => C.scaleMultiply(dist, n)->run(~env)->toDistR let scalePower = (~env, dist, n) => C.scalePower(dist, n)->run(~env)->toDistR let scaleLogarithm = (~env, dist, n) => C.scaleLogarithm(dist, n)->run(~env)->toDistR let pointwiseAdd = (~env, dist1, dist2) => C.pointwiseAdd(dist1, dist2)->run(~env)->toDistR diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi index 454b2729..bfa7b3ad 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation.resi @@ -102,6 +102,8 @@ module Constructors: { @genType let scaleLogarithm: (~env: env, genericDist, float) => result @genType + let scaleMultiply: (~env: env, genericDist, float) => result + @genType let scalePower: (~env: env, genericDist, float) => result @genType let pointwiseAdd: (~env: env, genericDist, genericDist) => result diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res index cdb0ee3d..a23c2cd6 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionTypes.res @@ -76,6 +76,7 @@ module DistributionOperation = { ] type toScaleFn = [ + | #Multiply | #Power | #Logarithm | #LogarithmWithThreshold(float) @@ -138,6 +139,7 @@ module DistributionOperation = { | ToDist(Truncate(_, _)) => `truncate` | ToDist(Inspect) => `inspect` | ToDist(Scale(#Power, r)) => `scalePower(${E.Float.toFixed(r)})` + | ToDist(Scale(#Multiply, r)) => `scaleMultiply(${E.Float.toFixed(r)})` | ToDist(Scale(#Logarithm, r)) => `scaleLog(${E.Float.toFixed(r)})` | ToDist(Scale(#LogarithmWithThreshold(eps), r)) => `scaleLogWithThreshold(${E.Float.toFixed(r)}, epsilon=${E.Float.toFixed(eps)})` @@ -179,6 +181,7 @@ module Constructors = { ToScore(LogScore(answer, prior)), prediction, ) + let scaleMultiply = (dist, n): t => FromDist(ToDist(Scale(#Multiply, n)), dist) let scalePower = (dist, n): t => FromDist(ToDist(Scale(#Power, n)), dist) let scaleLogarithm = (dist, n): t => FromDist(ToDist(Scale(#Logarithm, n)), dist) let scaleLogarithmWithThreshold = (dist, n, eps): t => FromDist( diff --git a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res index c0418df8..762f4125 100644 --- a/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res +++ b/packages/squiggle-lang/src/rescript/ReducerInterface/ReducerInterface_GenericDistribution.res @@ -266,6 +266,8 @@ let dispatchToGenericOutput = ( Helpers.toDistFn(Scale(#Logarithm, float), dist, ~env) | ("scaleLogWithThreshold", [EvDistribution(dist), EvNumber(base), EvNumber(eps)]) => Helpers.toDistFn(Scale(#LogarithmWithThreshold(eps), base), dist, ~env) + | ("scaleMultiply", [EvDistribution(dist), EvNumber(float)]) => + Helpers.toDistFn(Scale(#Multiply, float), dist, ~env) | ("scalePow", [EvDistribution(dist), EvNumber(float)]) => Helpers.toDistFn(Scale(#Power, float), dist, ~env) | ("scaleExp", [EvDistribution(dist)]) =>