Added genType to SampleSetDist to make pass tests, other minor fixes
This commit is contained in:
		
							parent
							
								
									9ad73fe69b
								
							
						
					
					
						commit
						4338f482ef
					
				|  | @ -128,7 +128,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => { | |||
|       ->E.R2.fmap(r => Dist(r)) | ||||
|       ->OutputLocal.fromResult | ||||
|     | ToDist(ToSampleSet(n)) => | ||||
|       dist->GenericDist.toSampleSetDist(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult | ||||
|       dist | ||||
|       ->GenericDist.toSampleSetDist(n) | ||||
|       ->E.R2.fmap(r => Dist(SampleSet(r))) | ||||
|       ->OutputLocal.fromResult | ||||
|     | ToDist(ToPointSet) => | ||||
|       dist | ||||
|       ->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ()) | ||||
|  | @ -204,7 +207,8 @@ module Constructors = { | |||
|     C.truncate(dist, leftCutoff, rightCutoff)->run(~env)->toDistR | ||||
|   let inspect = (~env, dist) => C.inspect(dist)->run(~env)->toDistR | ||||
|   let toString = (~env, dist) => C.toString(dist)->run(~env)->toStringR | ||||
|   let toSparkline = (~env, dist, bucketCount) => C.toSparkline(dist, bucketCount)->run(~env)->toStringR | ||||
|   let toSparkline = (~env, dist, bucketCount) => | ||||
|     C.toSparkline(dist, bucketCount)->run(~env)->toStringR | ||||
|   let algebraicAdd = (~env, dist1, dist2) => C.algebraicAdd(dist1, dist2)->run(~env)->toDistR | ||||
|   let algebraicMultiply = (~env, dist1, dist2) => | ||||
|     C.algebraicMultiply(dist1, dist2)->run(~env)->toDistR | ||||
|  | @ -213,8 +217,7 @@ module Constructors = { | |||
|     C.algebraicSubtract(dist1, dist2)->run(~env)->toDistR | ||||
|   let algebraicLogarithm = (~env, dist1, dist2) => | ||||
|     C.algebraicLogarithm(dist1, dist2)->run(~env)->toDistR | ||||
|   let algebraicPower = (~env, dist1, dist2) => | ||||
|     C.algebraicPower(dist1, dist2)->run(~env)->toDistR | ||||
|   let algebraicPower = (~env, dist1, dist2) => C.algebraicPower(dist1, dist2)->run(~env)->toDistR | ||||
|   let pointwiseAdd = (~env, dist1, dist2) => C.pointwiseAdd(dist1, dist2)->run(~env)->toDistR | ||||
|   let pointwiseMultiply = (~env, dist1, dist2) => | ||||
|     C.pointwiseMultiply(dist1, dist2)->run(~env)->toDistR | ||||
|  | @ -223,6 +226,5 @@ module Constructors = { | |||
|     C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR | ||||
|   let pointwiseLogarithm = (~env, dist1, dist2) => | ||||
|     C.pointwiseLogarithm(dist1, dist2)->run(~env)->toDistR | ||||
|   let pointwisePower = (~env, dist1, dist2) => | ||||
|     C.pointwisePower(dist1, dist2)->run(~env)->toDistR | ||||
|   let pointwisePower = (~env, dist1, dist2) => C.pointwisePower(dist1, dist2)->run(~env)->toDistR | ||||
| } | ||||
|  |  | |||
|  | @ -5,14 +5,16 @@ type toPointSetFn = t => result<PointSetTypes.pointSetDist, error> | |||
| type toSampleSetFn = t => result<SampleSetDist.t, error> | ||||
| type scaleMultiplyFn = (t, float) => result<t, error> | ||||
| type pointwiseAddFn = (t, t) => result<t, error> | ||||
| 
 | ||||
| let sampleN = (t: t, n) => | ||||
|   switch t { | ||||
|   | PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r)) | ||||
|   | Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r)) | ||||
|   | SampleSet(r) => Ok(SampleSetDist.sampleN(r, n)) | ||||
|   | PointSet(r) => PointSetDist.sampleNRendered(n, r) | ||||
|   | Symbolic(r) => SymbolicDist.T.sampleN(n, r) | ||||
|   | SampleSet(r) => SampleSetDist.sampleN(r, n) | ||||
|   } | ||||
| 
 | ||||
| let toSampleSetDist = (t: t, n) => | ||||
|   sampleN(t, n)->E.R.bind(SampleSetDist.make)->GenericDist_Types.Error.resultStringToResultError | ||||
|   SampleSetDist.make(sampleN(t, n))->GenericDist_Types.Error.resultStringToResultError | ||||
| 
 | ||||
| let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f)) | ||||
| 
 | ||||
|  | @ -72,7 +74,6 @@ let toPointSet = ( | |||
|         pointSetDistLength: xyPointLength, | ||||
|         kernelWidth: None, | ||||
|       }, | ||||
|       (), | ||||
|     )->GenericDist_Types.Error.resultStringToResultError | ||||
|   } | ||||
| } | ||||
|  | @ -162,14 +163,12 @@ module AlgebraicCombination = { | |||
|     t1: t, | ||||
|     t2: t, | ||||
|   ) => { | ||||
|     let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation) | ||||
|     E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R.bind(((a, b)) => { | ||||
|       SampleSetDist.map2( | ||||
|         ~fn=arithmeticOperation, | ||||
|         ~t1=a, | ||||
|         ~t2=b, | ||||
|       )->GenericDist_Types.Error.resultStringToResultError | ||||
|     let fn = Operation.Algebraic.toFn(arithmeticOperation) | ||||
|     E.R.merge(toSampleSet(t1), toSampleSet(t2)) | ||||
|     ->E.R.bind(((t1, t2)) => { | ||||
|       SampleSetDist.map2(~fn, ~t1, ~t2)->GenericDist_Types.Error.resultStringToResultError | ||||
|     }) | ||||
|     ->E.R2.fmap(r => GenericDist_Types.SampleSet(r)) | ||||
|   } | ||||
| 
 | ||||
|   //I'm (Ozzie) really just guessing here, very little idea what's best | ||||
|  | @ -200,15 +199,7 @@ module AlgebraicCombination = { | |||
|     | Some(Error(e)) => Error(Other(e)) | ||||
|     | None => | ||||
|       switch chooseConvolutionOrMonteCarlo(t1, t2) { | ||||
|       | #CalculateWithMonteCarlo => { | ||||
|           let sampleSetDist: result<SampleSetDist.t, error> = runMonteCarlo( | ||||
|             toSampleSetFn, | ||||
|             arithmeticOperation, | ||||
|             t1, | ||||
|             t2, | ||||
|           ) | ||||
|           sampleSetDist->E.R2.fmap(r => GenericDist_Types.SampleSet(r)) | ||||
|         } | ||||
|       | #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2) | ||||
|       | #CalculateWithConvolution => | ||||
|         runConvolution( | ||||
|           toPointSetFn, | ||||
|  | @ -274,7 +265,7 @@ let mixture = ( | |||
|   ~pointwiseAddFn: pointwiseAddFn, | ||||
| ) => { | ||||
|   if E.A.length(values) == 0 { | ||||
|     Error(GenericDist_Types.Other("mixture must have at least 1 element")) | ||||
|     Error(GenericDist_Types.Other("Mixture error: mixture must have at least 1 element")) | ||||
|   } else { | ||||
|     let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum | ||||
|     let properlyWeightedValues = | ||||
|  |  | |||
|  | @ -5,7 +5,8 @@ type toSampleSetFn = t => result<SampleSetDist.t, error> | |||
| type scaleMultiplyFn = (t, float) => result<t, error> | ||||
| type pointwiseAddFn = (t, t) => result<t, error> | ||||
| 
 | ||||
| let sampleN: (t, int) => result<array<float>, error> | ||||
| let sampleN: (t, int) => array<float> | ||||
| 
 | ||||
| let toSampleSetDist: (t, int) => Belt.Result.t<QuriSquiggleLang.SampleSetDist.t, error> | ||||
| 
 | ||||
| let fromFloat: float => t | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| module T: { | ||||
|   @genType | ||||
|   type t | ||||
|   let make: array<float> => result<t, string> | ||||
|   let get: t => array<float> | ||||
|  | @ -18,7 +19,7 @@ include T | |||
| let length = (t: t) => get(t) |> E.A.length | ||||
| 
 | ||||
| // TODO: Refactor to get error in the toPointSetDist function, instead of adding at very end. | ||||
| let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()): result< | ||||
| let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs): result< | ||||
|   PointSetTypes.pointSetDist, | ||||
|   string, | ||||
| > => | ||||
|  |  | |||
|  | @ -225,7 +225,7 @@ module SamplingDistribution = { | |||
|       let pointSetDist =  | ||||
|         sampleSetDist | ||||
|         -> E.R.bind(r => | ||||
|           SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ())); | ||||
|           SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r)); | ||||
|       pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r))) | ||||
|     }) | ||||
|   } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user