Changed genericDist from being a polymorphic variant
This commit is contained in:
		
							parent
							
								
									4b3f24b38d
								
							
						
					
					
						commit
						680726e8b0
					
				| 
						 | 
					@ -6,10 +6,10 @@ let env: GenericDist_GenericOperation.env = {
 | 
				
			||||||
  xyPointLength: 100,
 | 
					  xyPointLength: 100,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
let normalDist: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 5.0, stdev: 2.0}))
 | 
					let normalDist: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 5.0, stdev: 2.0}))
 | 
				
			||||||
let normalDist10: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 10.0, stdev: 2.0}))
 | 
					let normalDist10: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 10.0, stdev: 2.0}))
 | 
				
			||||||
let normalDist20: GenericDist_Types.genericDist = #Symbolic(#Normal({mean: 20.0, stdev: 2.0}))
 | 
					let normalDist20: GenericDist_Types.genericDist = Symbolic(#Normal({mean: 20.0, stdev: 2.0}))
 | 
				
			||||||
let uniformDist: GenericDist_Types.genericDist = #Symbolic(#Uniform({low: 9.0, high: 10.0}))
 | 
					let uniformDist: GenericDist_Types.genericDist = Symbolic(#Uniform({low: 9.0, high: 10.0}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output)
 | 
					let {toFloat, toDist, toString, toError} = module(GenericDist_GenericOperation.Output)
 | 
				
			||||||
let {run} = module(GenericDist_GenericOperation)
 | 
					let {run} = module(GenericDist_GenericOperation)
 | 
				
			||||||
| 
						 | 
					@ -57,7 +57,7 @@ describe("toPointSet", () => {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  test("on sample set distribution with under 4 points", () => {
 | 
					  test("on sample set distribution with under 4 points", () => {
 | 
				
			||||||
    let result =
 | 
					    let result =
 | 
				
			||||||
      run(#fromDist(#toDist(#toPointSet), #SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
 | 
					      run(#fromDist(#toDist(#toPointSet), SampleSet([0.0, 1.0, 2.0, 3.0])))->outputMap(
 | 
				
			||||||
        #fromDist(#toFloat(#Mean)),
 | 
					        #fromDist(#toFloat(#Mean)),
 | 
				
			||||||
      )
 | 
					      )
 | 
				
			||||||
    expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))
 | 
					    expect(result)->toEqual(GenDistError(Other("Converting sampleSet to pointSet failed")))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,25 +8,25 @@ type pointwiseAddFn = (t, t) => result<t, error>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
let sampleN = (t: t, n) =>
 | 
					let sampleN = (t: t, n) =>
 | 
				
			||||||
  switch t {
 | 
					  switch t {
 | 
				
			||||||
  | #PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
 | 
					  | PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
 | 
				
			||||||
  | #Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
 | 
					  | Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
 | 
				
			||||||
  | #SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
 | 
					  | SampleSet(_) => Error(GenericDist_Types.NotYetImplemented)
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
let fromFloat = (f: float) => #Symbolic(SymbolicDist.Float.make(f))
 | 
					let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
let toString = (t: t) =>
 | 
					let toString = (t: t) =>
 | 
				
			||||||
  switch t {
 | 
					  switch t {
 | 
				
			||||||
  | #PointSet(_) => "Point Set Distribution"
 | 
					  | PointSet(_) => "Point Set Distribution"
 | 
				
			||||||
  | #Symbolic(r) => SymbolicDist.T.toString(r)
 | 
					  | Symbolic(r) => SymbolicDist.T.toString(r)
 | 
				
			||||||
  | #SampleSet(_) => "Sample Set Distribution"
 | 
					  | SampleSet(_) => "Sample Set Distribution"
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
let normalize = (t: t) =>
 | 
					let normalize = (t: t): t =>
 | 
				
			||||||
  switch t {
 | 
					  switch t {
 | 
				
			||||||
  | #PointSet(r) => #PointSet(PointSetDist.T.normalize(r))
 | 
					  | PointSet(r) => PointSet(PointSetDist.T.normalize(r))
 | 
				
			||||||
  | #Symbolic(_) => t
 | 
					  | Symbolic(_) => t
 | 
				
			||||||
  | #SampleSet(_) => t
 | 
					  | SampleSet(_) => t
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
let toFloatOperation = (
 | 
					let toFloatOperation = (
 | 
				
			||||||
| 
						 | 
					@ -34,8 +34,8 @@ let toFloatOperation = (
 | 
				
			||||||
  ~toPointSetFn: toPointSetFn,
 | 
					  ~toPointSetFn: toPointSetFn,
 | 
				
			||||||
  ~distToFloatOperation: Operation.distToFloatOperation,
 | 
					  ~distToFloatOperation: Operation.distToFloatOperation,
 | 
				
			||||||
) => {
 | 
					) => {
 | 
				
			||||||
  let symbolicSolution = switch t {
 | 
					  let symbolicSolution = switch (t: t) {
 | 
				
			||||||
  | #Symbolic(r) =>
 | 
					  | Symbolic(r) =>
 | 
				
			||||||
    switch SymbolicDist.T.operate(distToFloatOperation, r) {
 | 
					    switch SymbolicDist.T.operate(distToFloatOperation, r) {
 | 
				
			||||||
    | Ok(f) => Some(f)
 | 
					    | Ok(f) => Some(f)
 | 
				
			||||||
    | _ => None
 | 
					    | _ => None
 | 
				
			||||||
| 
						 | 
					@ -53,10 +53,10 @@ let toFloatOperation = (
 | 
				
			||||||
// This is tricky because the case of discrete distributions.
 | 
					// This is tricky because the case of discrete distributions.
 | 
				
			||||||
// Also, change the outputXYPoints/pointSetDistLength details
 | 
					// Also, change the outputXYPoints/pointSetDistLength details
 | 
				
			||||||
let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSetDist, error> => {
 | 
					let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSetDist, error> => {
 | 
				
			||||||
  switch t {
 | 
					  switch (t: t) {
 | 
				
			||||||
  | #PointSet(pointSet) => Ok(pointSet)
 | 
					  | PointSet(pointSet) => Ok(pointSet)
 | 
				
			||||||
  | #Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
 | 
					  | Symbolic(r) => Ok(SymbolicDist.T.toPointSetDist(xyPointLength, r))
 | 
				
			||||||
  | #SampleSet(r) => {
 | 
					  | SampleSet(r) => {
 | 
				
			||||||
      let response = SampleSet.toPointSetDist(
 | 
					      let response = SampleSet.toPointSetDist(
 | 
				
			||||||
        ~samples=r,
 | 
					        ~samples=r,
 | 
				
			||||||
        ~samplingInputs={
 | 
					        ~samplingInputs={
 | 
				
			||||||
| 
						 | 
					@ -76,11 +76,11 @@ let toPointSet = (~xyPointLength, ~sampleCount, t): result<PointSetTypes.pointSe
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
module Truncate = {
 | 
					module Truncate = {
 | 
				
			||||||
  let trySymbolicSimplification = (leftCutoff, rightCutoff, t): option<t> =>
 | 
					  let trySymbolicSimplification = (leftCutoff, rightCutoff, t: t): option<t> =>
 | 
				
			||||||
    switch (leftCutoff, rightCutoff, t) {
 | 
					    switch (leftCutoff, rightCutoff, t) {
 | 
				
			||||||
    | (None, None, _) => None
 | 
					    | (None, None, _) => None
 | 
				
			||||||
    | (lc, rc, #Symbolic(#Uniform(u))) if lc < rc =>
 | 
					    | (lc, rc, Symbolic(#Uniform(u))) if lc < rc =>
 | 
				
			||||||
      Some(#Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
 | 
					      Some(Symbolic(#Uniform(SymbolicDist.Uniform.truncate(lc, rc, u))))
 | 
				
			||||||
    | _ => None
 | 
					    | _ => None
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -98,9 +98,9 @@ module Truncate = {
 | 
				
			||||||
      switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
 | 
					      switch trySymbolicSimplification(leftCutoff, rightCutoff, t) {
 | 
				
			||||||
      | Some(r) => Ok(r)
 | 
					      | Some(r) => Ok(r)
 | 
				
			||||||
      | None =>
 | 
					      | None =>
 | 
				
			||||||
        toPointSetFn(t)->E.R2.fmap(t =>
 | 
					        toPointSetFn(t)->E.R2.fmap(t => {
 | 
				
			||||||
          #PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t))
 | 
					          GenericDist_Types.PointSet(PointSetDist.T.truncate(leftCutoff, rightCutoff, t))
 | 
				
			||||||
        )
 | 
					        })
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -122,7 +122,7 @@ module AlgebraicCombination = {
 | 
				
			||||||
    t2: t,
 | 
					    t2: t,
 | 
				
			||||||
  ): option<result<SymbolicDistTypes.symbolicDist, string>> =>
 | 
					  ): option<result<SymbolicDistTypes.symbolicDist, string>> =>
 | 
				
			||||||
    switch (arithmeticOperation, t1, t2) {
 | 
					    switch (arithmeticOperation, t1, t2) {
 | 
				
			||||||
    | (arithmeticOperation, #Symbolic(d1), #Symbolic(d2)) =>
 | 
					    | (arithmeticOperation, Symbolic(d1), Symbolic(d2)) =>
 | 
				
			||||||
      switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) {
 | 
					      switch SymbolicDist.T.tryAnalyticalSimplification(d1, d2, arithmeticOperation) {
 | 
				
			||||||
      | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist))
 | 
					      | #AnalyticalSolution(symbolicDist) => Some(Ok(symbolicDist))
 | 
				
			||||||
      | #Error(er) => Some(Error(er))
 | 
					      | #Error(er) => Some(Error(er))
 | 
				
			||||||
| 
						 | 
					@ -156,11 +156,11 @@ module AlgebraicCombination = {
 | 
				
			||||||
  //I'm (Ozzie) really just guessing here, very little idea what's best
 | 
					  //I'm (Ozzie) really just guessing here, very little idea what's best
 | 
				
			||||||
  let expectedConvolutionCost: t => int = x =>
 | 
					  let expectedConvolutionCost: t => int = x =>
 | 
				
			||||||
    switch x {
 | 
					    switch x {
 | 
				
			||||||
    | #Symbolic(#Float(_)) => 1
 | 
					    | Symbolic(#Float(_)) => 1
 | 
				
			||||||
    | #Symbolic(_) => 1000
 | 
					    | Symbolic(_) => 1000
 | 
				
			||||||
    | #PointSet(Discrete(m)) => m.xyShape->XYShape.T.length
 | 
					    | PointSet(Discrete(m)) => m.xyShape->XYShape.T.length
 | 
				
			||||||
    | #PointSet(Mixed(_)) => 1000
 | 
					    | PointSet(Mixed(_)) => 1000
 | 
				
			||||||
    | #PointSet(Continuous(_)) => 1000
 | 
					    | PointSet(Continuous(_)) => 1000
 | 
				
			||||||
    | _ => 1000
 | 
					    | _ => 1000
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -177,14 +177,24 @@ module AlgebraicCombination = {
 | 
				
			||||||
    ~t2: t,
 | 
					    ~t2: t,
 | 
				
			||||||
  ): result<t, error> => {
 | 
					  ): result<t, error> => {
 | 
				
			||||||
    switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
 | 
					    switch tryAnalyticalSimplification(arithmeticOperation, t1, t2) {
 | 
				
			||||||
    | Some(Ok(symbolicDist)) => Ok(#Symbolic(symbolicDist))
 | 
					    | Some(Ok(symbolicDist)) => Ok(Symbolic(symbolicDist))
 | 
				
			||||||
    | Some(Error(e)) => Error(Other(e))
 | 
					    | Some(Error(e)) => Error(Other(e))
 | 
				
			||||||
    | None =>
 | 
					    | None =>
 | 
				
			||||||
      switch chooseConvolutionOrMonteCarlo(t1, t2) {
 | 
					      switch chooseConvolutionOrMonteCarlo(t1, t2) {
 | 
				
			||||||
      | #CalculateWithMonteCarlo =>
 | 
					      | #CalculateWithMonteCarlo =>
 | 
				
			||||||
        runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(r => #SampleSet(r))
 | 
					        runMonteCarlo(
 | 
				
			||||||
 | 
					          toSampleSetFn,
 | 
				
			||||||
 | 
					          arithmeticOperation,
 | 
				
			||||||
 | 
					          t1,
 | 
				
			||||||
 | 
					          t2,
 | 
				
			||||||
 | 
					        )->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
 | 
				
			||||||
      | #CalculateWithConvolution =>
 | 
					      | #CalculateWithConvolution =>
 | 
				
			||||||
        runConvolution(toPointSetFn, arithmeticOperation, t1, t2)->E.R2.fmap(r => #PointSet(r))
 | 
					        runConvolution(
 | 
				
			||||||
 | 
					          toPointSetFn,
 | 
				
			||||||
 | 
					          arithmeticOperation,
 | 
				
			||||||
 | 
					          t1,
 | 
				
			||||||
 | 
					          t2,
 | 
				
			||||||
 | 
					        )->E.R2.fmap(r => GenericDist_Types.PointSet(r))
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -207,7 +217,7 @@ let pointwiseCombination = (
 | 
				
			||||||
      t2,
 | 
					      t2,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
  ->E.R2.fmap(r => #PointSet(r))
 | 
					  ->E.R2.fmap(r => GenericDist_Types.PointSet(r))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
let pointwiseCombinationFloat = (
 | 
					let pointwiseCombinationFloat = (
 | 
				
			||||||
| 
						 | 
					@ -232,7 +242,7 @@ let pointwiseCombinationFloat = (
 | 
				
			||||||
      )
 | 
					      )
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  m->E.R2.fmap(r => #PointSet(r))
 | 
					  m->E.R2.fmap(r => GenericDist_Types.PointSet(r))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//Note: The result should always cumulatively sum to 1. This would be good to test.
 | 
					//Note: The result should always cumulatively sum to 1. This would be good to test.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -71,14 +71,14 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  let toPointSetFn = r => {
 | 
					  let toPointSetFn = r => {
 | 
				
			||||||
    switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) {
 | 
					    switch reCall(~functionCallInfo=#fromDist(#toDist(#toPointSet), r), ()) {
 | 
				
			||||||
    | Dist(#PointSet(p)) => Ok(p)
 | 
					    | Dist(PointSet(p)) => Ok(p)
 | 
				
			||||||
    | e => Error(OutputLocal.toErrorOrUnreachable(e))
 | 
					    | e => Error(OutputLocal.toErrorOrUnreachable(e))
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  let toSampleSetFn = r => {
 | 
					  let toSampleSetFn = r => {
 | 
				
			||||||
    switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
 | 
					    switch reCall(~functionCallInfo=#fromDist(#toDist(#toSampleSet(sampleCount)), r), ()) {
 | 
				
			||||||
    | Dist(#SampleSet(p)) => Ok(p)
 | 
					    | Dist(SampleSet(p)) => Ok(p)
 | 
				
			||||||
    | e => Error(OutputLocal.toErrorOrUnreachable(e))
 | 
					    | e => Error(OutputLocal.toErrorOrUnreachable(e))
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -114,10 +114,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
 | 
				
			||||||
    | #toDist(#toPointSet) =>
 | 
					    | #toDist(#toPointSet) =>
 | 
				
			||||||
      dist
 | 
					      dist
 | 
				
			||||||
      ->GenericDist.toPointSet(~xyPointLength, ~sampleCount)
 | 
					      ->GenericDist.toPointSet(~xyPointLength, ~sampleCount)
 | 
				
			||||||
      ->E.R2.fmap(r => Dist(#PointSet(r)))
 | 
					      ->E.R2.fmap(r => Dist(PointSet(r)))
 | 
				
			||||||
      ->OutputLocal.fromResult
 | 
					      ->OutputLocal.fromResult
 | 
				
			||||||
    | #toDist(#toSampleSet(n)) =>
 | 
					    | #toDist(#toSampleSet(n)) =>
 | 
				
			||||||
      dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(#SampleSet(r)))->OutputLocal.fromResult
 | 
					      dist->GenericDist.sampleN(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
 | 
				
			||||||
    | #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
 | 
					    | #toDistCombination(#Algebraic, _, #Float(_)) => GenDistError(NotYetImplemented)
 | 
				
			||||||
    | #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) =>
 | 
					    | #toDistCombination(#Algebraic, arithmeticOperation, #Dist(t2)) =>
 | 
				
			||||||
      dist
 | 
					      dist
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,7 +15,11 @@ let runFromDist: (
 | 
				
			||||||
  ~functionCallInfo: GenericDist_Types.Operation.fromDist,
 | 
					  ~functionCallInfo: GenericDist_Types.Operation.fromDist,
 | 
				
			||||||
  GenericDist_Types.genericDist,
 | 
					  GenericDist_Types.genericDist,
 | 
				
			||||||
) => outputType
 | 
					) => outputType
 | 
				
			||||||
let runFromFloat: (~env: env, ~functionCallInfo: GenericDist_Types.Operation.fromDist, float) => outputType
 | 
					let runFromFloat: (
 | 
				
			||||||
 | 
					  ~env: env,
 | 
				
			||||||
 | 
					  ~functionCallInfo: GenericDist_Types.Operation.fromDist,
 | 
				
			||||||
 | 
					  float,
 | 
				
			||||||
 | 
					) => outputType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
module Output: {
 | 
					module Output: {
 | 
				
			||||||
  type t = outputType
 | 
					  type t = outputType
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,8 +1,7 @@
 | 
				
			||||||
type genericDist = [
 | 
					type genericDist =
 | 
				
			||||||
  | #PointSet(PointSetTypes.pointSetDist)
 | 
					  | PointSet(PointSetTypes.pointSetDist)
 | 
				
			||||||
  | #SampleSet(array<float>)
 | 
					  | SampleSet(array<float>)
 | 
				
			||||||
  | #Symbolic(SymbolicDistTypes.symbolicDist)
 | 
					  | Symbolic(SymbolicDistTypes.symbolicDist)
 | 
				
			||||||
]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
type error =
 | 
					type error =
 | 
				
			||||||
  | NotYetImplemented
 | 
					  | NotYetImplemented
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user