diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res index 4d441a9a..62f9d8dd 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Mixed.res @@ -161,24 +161,20 @@ module T = Dist({ let integralYtoX = (f, t) => t |> integral |> Continuous.getShape |> XYShape.YtoX.linear(f) - // This pipes all ys (continuous and discrete) through fn. - // If mapY is a linear operation, we might be able to update the integralSumCaches as well; - // if not, they'll be set to None. - let mapY = ( + let createMixedFromContinuousDiscrete = ( ~integralSumCacheFn=_ => None, ~integralCacheFn=_ => None, - ~fn: float => float, t: t, + discrete: PointSetTypes.discreteShape, + continuous: PointSetTypes.continuousShape, ): t => { let yMappedDiscrete: PointSetTypes.discreteShape = - t.discrete - |> Discrete.T.mapY(~fn) + discrete |> Discrete.updateIntegralSumCache(E.O.bind(t.discrete.integralSumCache, integralSumCacheFn)) |> Discrete.updateIntegralCache(E.O.bind(t.discrete.integralCache, integralCacheFn)) let yMappedContinuous: PointSetTypes.continuousShape = - t.continuous - |> Continuous.T.mapY(~fn) + continuous |> Continuous.updateIntegralSumCache( E.O.bind(t.continuous.integralSumCache, integralSumCacheFn), ) @@ -192,6 +188,26 @@ module T = Dist({ } } + // This pipes all ys (continuous and discrete) through fn. + // If mapY is a linear operation, we might be able to update the integralSumCaches as well; + // if not, they'll be set to None. + let mapY = ( + ~integralSumCacheFn=_ => None, + ~integralCacheFn=_ => None, + ~fn: float => float, + t: t, + ): t => { + let discrete = t.discrete |> Discrete.T.mapY(~fn) + let continuous = t.continuous |> Continuous.T.mapY(~fn) + createMixedFromContinuousDiscrete( + ~integralCacheFn, + ~integralSumCacheFn, + t, + discrete, + continuous, + ) + } + let mapYResult = ( ~integralSumCacheFn=_ => None, ~integralCacheFn=_ => None, @@ -202,27 +218,12 @@ module T = Dist({ Discrete.T.mapYResult(~fn, t.discrete), Continuous.T.mapYResult(~fn, t.continuous), )->E.R2.fmap(((discreteMapped, continuousMapped)) => { - let yMappedDiscrete: PointSetTypes.discreteShape = - discreteMapped - |> Discrete.updateIntegralSumCache( - E.O.bind(t.discrete.integralSumCache, integralSumCacheFn), - ) - |> Discrete.updateIntegralCache(E.O.bind(t.discrete.integralCache, integralCacheFn)) - - let yMappedContinuous: PointSetTypes.continuousShape = - continuousMapped - |> Continuous.updateIntegralSumCache( - E.O.bind(t.continuous.integralSumCache, integralSumCacheFn), - ) - |> Continuous.updateIntegralCache(E.O.bind(t.continuous.integralCache, integralCacheFn)) - - ( - { - discrete: yMappedDiscrete, - continuous: yMappedContinuous, - integralSumCache: E.O.bind(t.integralSumCache, integralSumCacheFn), - integralCache: E.O.bind(t.integralCache, integralCacheFn), - }: t + createMixedFromContinuousDiscrete( + ~integralCacheFn, + ~integralSumCacheFn, + t, + discreteMapped, + continuousMapped, ) }) }