Quick bug fixes for pow, log, and toPointSet
This commit is contained in:
		
							parent
							
								
									cf9c12f786
								
							
						
					
					
						commit
						81e478ba49
					
				| 
						 | 
					@ -83,7 +83,8 @@ let toPointSet = (
 | 
				
			||||||
        pointSetDistLength: xyPointLength,
 | 
					        pointSetDistLength: xyPointLength,
 | 
				
			||||||
        kernelWidth: None,
 | 
					        kernelWidth: None,
 | 
				
			||||||
      },
 | 
					      },
 | 
				
			||||||
    )->GenericDist_Types.Error.resultStringToResultError
 | 
					    )
 | 
				
			||||||
 | 
					    ->GenericDist_Types.Error.resultStringToResultError
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -161,10 +162,12 @@ module AlgebraicCombination = {
 | 
				
			||||||
    arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
 | 
					    arithmeticOperation: GenericDist_Types.Operation.arithmeticOperation,
 | 
				
			||||||
    t1: t,
 | 
					    t1: t,
 | 
				
			||||||
    t2: t,
 | 
					    t2: t,
 | 
				
			||||||
  ) =>
 | 
					  ) => {
 | 
				
			||||||
 | 
					    let normalize = PointSetDist.T.normalize
 | 
				
			||||||
    E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) =>
 | 
					    E.R.merge(toPointSet(t1), toPointSet(t2))->E.R2.fmap(((a, b)) =>
 | 
				
			||||||
      PointSetDist.combineAlgebraically(arithmeticOperation, a, b)
 | 
					      PointSetDist.combineAlgebraically(arithmeticOperation, normalize(a), normalize(b))->normalize
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  let runMonteCarlo = (
 | 
					  let runMonteCarlo = (
 | 
				
			||||||
    toSampleSet: toSampleSetFn,
 | 
					    toSampleSet: toSampleSetFn,
 | 
				
			||||||
| 
						 | 
					@ -196,6 +199,50 @@ module AlgebraicCombination = {
 | 
				
			||||||
      ? #CalculateWithMonteCarlo
 | 
					      ? #CalculateWithMonteCarlo
 | 
				
			||||||
      : #CalculateWithConvolution
 | 
					      : #CalculateWithConvolution
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let getLogarithmInputError = (t1: t, t2: t, ~toPointSetFn: toPointSetFn): option<error> => {
 | 
				
			||||||
 | 
					    let firstOperandIsGreaterThanZero =
 | 
				
			||||||
 | 
					      toFloatOperation(t1, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r =>
 | 
				
			||||||
 | 
					        r > 0.
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					    let secondOperandIsGreaterThanZero =
 | 
				
			||||||
 | 
					      toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Cdf(1e-10)) |> E.R.fmap(r =>
 | 
				
			||||||
 | 
					        r > 0.
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					    let secondOperandHasMassAt1 =
 | 
				
			||||||
 | 
					      toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Pdf(1.0)) |> E.R.fmap(r =>
 | 
				
			||||||
 | 
					        r >= 1e-10
 | 
				
			||||||
 | 
					      )
 | 
				
			||||||
 | 
					    let items = E.A.R.firstErrorOrOpen([
 | 
				
			||||||
 | 
					      firstOperandIsGreaterThanZero,
 | 
				
			||||||
 | 
					      secondOperandIsGreaterThanZero,
 | 
				
			||||||
 | 
					      secondOperandHasMassAt1,
 | 
				
			||||||
 | 
					    ])
 | 
				
			||||||
 | 
					    Js.log2("PMASS", toFloatOperation(t2, ~toPointSetFn, ~distToFloatOperation=#Pdf(1.0)))
 | 
				
			||||||
 | 
					    Js.log4("HIHI", items, t1, t2)
 | 
				
			||||||
 | 
					    switch items {
 | 
				
			||||||
 | 
					    | Error(r) => Some(r)
 | 
				
			||||||
 | 
					    | Ok([true, _, _]) => Some(Other("First input of logarithm must be fully greater than 0"))
 | 
				
			||||||
 | 
					    | Ok([false, true, _]) => Some(Other("Second input of logarithm must be fully greater than 0"))
 | 
				
			||||||
 | 
					    | Ok([false, false, true]) =>
 | 
				
			||||||
 | 
					      Some(Other("Second input of logarithm cannot have probability mass at 1.0"))
 | 
				
			||||||
 | 
					    | Ok([false, false, false]) => None
 | 
				
			||||||
 | 
					    | Ok(_) => Some(Unreachable)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let getInvalidOperationError = (
 | 
				
			||||||
 | 
					    t1: t,
 | 
				
			||||||
 | 
					    t2: t,
 | 
				
			||||||
 | 
					    ~toPointSetFn: toPointSetFn,
 | 
				
			||||||
 | 
					    ~arithmeticOperation,
 | 
				
			||||||
 | 
					  ): option<error> => {
 | 
				
			||||||
 | 
					    if arithmeticOperation == #Logarithm {
 | 
				
			||||||
 | 
					      getLogarithmInputError(t1, t2, ~toPointSetFn)
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      None
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  let run = (
 | 
					  let run = (
 | 
				
			||||||
    t1: t,
 | 
					    t1: t,
 | 
				
			||||||
    ~toPointSetFn: toPointSetFn,
 | 
					    ~toPointSetFn: toPointSetFn,
 | 
				
			||||||
| 
						 | 
					@ -207,15 +254,19 @@ module AlgebraicCombination = {
 | 
				
			||||||
    | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist))
 | 
					    | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist))
 | 
				
			||||||
    | Some(Error(e)) => Error(Other(e))
 | 
					    | Some(Error(e)) => Error(Other(e))
 | 
				
			||||||
    | None =>
 | 
					    | None =>
 | 
				
			||||||
      switch chooseConvolutionOrMonteCarlo(t1, t2) {
 | 
					      switch getInvalidOperationError(t1, t2, ~toPointSetFn, ~arithmeticOperation) {
 | 
				
			||||||
      | #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
 | 
					      | Some(e) => Error(e)
 | 
				
			||||||
      | #CalculateWithConvolution =>
 | 
					      | None =>
 | 
				
			||||||
        runConvolution(
 | 
					        switch chooseConvolutionOrMonteCarlo(t1, t2) {
 | 
				
			||||||
          toPointSetFn,
 | 
					        | #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
 | 
				
			||||||
          arithmeticOperation,
 | 
					        | #CalculateWithConvolution =>
 | 
				
			||||||
          t1,
 | 
					          runConvolution(
 | 
				
			||||||
          t2,
 | 
					            toPointSetFn,
 | 
				
			||||||
        )->E.R2.fmap(r => DistributionTypes.PointSet(r))
 | 
					            arithmeticOperation,
 | 
				
			||||||
 | 
					            t1,
 | 
				
			||||||
 | 
					            t2,
 | 
				
			||||||
 | 
					          )->E.R2.fmap(r => DistributionTypes.PointSet(PointSetDist.T.normalize(r)))
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -93,7 +93,20 @@ module T = Dist({
 | 
				
			||||||
      t,
 | 
					      t,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  let normalize = fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize))
 | 
					  let integralEndY = mapToAll((
 | 
				
			||||||
 | 
					    Mixed.T.Integral.sum,
 | 
				
			||||||
 | 
					    Discrete.T.Integral.sum,
 | 
				
			||||||
 | 
					    Continuous.T.Integral.sum,
 | 
				
			||||||
 | 
					  ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let isNormalized = t => integralEndY(t) == 1.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let normalize = (t: t): t =>
 | 
				
			||||||
 | 
					    if isNormalized(t) {
 | 
				
			||||||
 | 
					      t
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      t |> fmap((Mixed.T.normalize, Discrete.T.normalize, Continuous.T.normalize))
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  let updateIntegralCache = (integralCache, t: t): t =>
 | 
					  let updateIntegralCache = (integralCache, t: t): t =>
 | 
				
			||||||
    fmap(
 | 
					    fmap(
 | 
				
			||||||
| 
						 | 
					@ -124,11 +137,6 @@ module T = Dist({
 | 
				
			||||||
    Discrete.T.Integral.get,
 | 
					    Discrete.T.Integral.get,
 | 
				
			||||||
    Continuous.T.Integral.get,
 | 
					    Continuous.T.Integral.get,
 | 
				
			||||||
  ))
 | 
					  ))
 | 
				
			||||||
  let integralEndY = mapToAll((
 | 
					 | 
				
			||||||
    Mixed.T.Integral.sum,
 | 
					 | 
				
			||||||
    Discrete.T.Integral.sum,
 | 
					 | 
				
			||||||
    Continuous.T.Integral.sum,
 | 
					 | 
				
			||||||
  ))
 | 
					 | 
				
			||||||
  let integralXtoY = f =>
 | 
					  let integralXtoY = f =>
 | 
				
			||||||
    mapToAll((Mixed.T.Integral.xToY(f), Discrete.T.Integral.xToY(f), Continuous.T.Integral.xToY(f)))
 | 
					    mapToAll((Mixed.T.Integral.xToY(f), Discrete.T.Integral.xToY(f), Continuous.T.Integral.xToY(f)))
 | 
				
			||||||
  let integralYtoX = f =>
 | 
					  let integralYtoX = f =>
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -64,5 +64,10 @@ let sampleN = (t: t, n) => {
 | 
				
			||||||
//TODO: Figure out what to do if distributions are different lengths. ``zip`` is kind of inelegant for this.
 | 
					//TODO: Figure out what to do if distributions are different lengths. ``zip`` is kind of inelegant for this.
 | 
				
			||||||
let map2 = (~fn: (float, float) => float, ~t1: t, ~t2: t) => {
 | 
					let map2 = (~fn: (float, float) => float, ~t1: t, ~t2: t) => {
 | 
				
			||||||
  let samples = Belt.Array.zip(get(t1), get(t2))->E.A2.fmap(((a, b)) => fn(a, b))
 | 
					  let samples = Belt.Array.zip(get(t1), get(t2))->E.A2.fmap(((a, b)) => fn(a, b))
 | 
				
			||||||
  make(samples)
 | 
					  let has_invalid_results = Belt.Array.some(samples, a => Js.Float.isNaN(a))
 | 
				
			||||||
 | 
					  if has_invalid_results {
 | 
				
			||||||
 | 
					    Error("Distribution combination produced invalid results")
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    make(samples)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -133,9 +133,12 @@ let toPointSetDist = (
 | 
				
			||||||
    ~discrete=Some(discrete),
 | 
					    ~discrete=Some(discrete),
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  //The latter doesn't always produce a normalized result, so we need to normalize it.
 | 
				
			||||||
 | 
					  let normalized = pointSetDist->E.O2.fmap(PointSetDist.T.normalize)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  let samplesParse: Internals.Types.outputs = {
 | 
					  let samplesParse: Internals.Types.outputs = {
 | 
				
			||||||
    continuousParseParams: pdf |> E.O.fmap(snd),
 | 
					    continuousParseParams: pdf |> E.O.fmap(snd),
 | 
				
			||||||
    pointSetDist: pointSetDist,
 | 
					    pointSetDist: normalized,
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  samplesParse
 | 
					  samplesParse
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user