Added genType to SampleSetDist to make pass tests, other minor fixes

This commit is contained in:
Ozzie Gooen 2022-04-09 21:24:44 -04:00
parent 9ad73fe69b
commit 4338f482ef
5 changed files with 26 additions and 31 deletions

View File

@ -128,7 +128,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult
| ToDist(ToSampleSet(n)) =>
dist->GenericDist.toSampleSetDist(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
dist
->GenericDist.toSampleSetDist(n)
->E.R2.fmap(r => Dist(SampleSet(r)))
->OutputLocal.fromResult
| ToDist(ToPointSet) =>
dist
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())
@ -204,7 +207,8 @@ module Constructors = {
C.truncate(dist, leftCutoff, rightCutoff)->run(~env)->toDistR
let inspect = (~env, dist) => C.inspect(dist)->run(~env)->toDistR
let toString = (~env, dist) => C.toString(dist)->run(~env)->toStringR
let toSparkline = (~env, dist, bucketCount) => C.toSparkline(dist, bucketCount)->run(~env)->toStringR
let toSparkline = (~env, dist, bucketCount) =>
C.toSparkline(dist, bucketCount)->run(~env)->toStringR
let algebraicAdd = (~env, dist1, dist2) => C.algebraicAdd(dist1, dist2)->run(~env)->toDistR
let algebraicMultiply = (~env, dist1, dist2) =>
C.algebraicMultiply(dist1, dist2)->run(~env)->toDistR
@ -213,8 +217,7 @@ module Constructors = {
C.algebraicSubtract(dist1, dist2)->run(~env)->toDistR
let algebraicLogarithm = (~env, dist1, dist2) =>
C.algebraicLogarithm(dist1, dist2)->run(~env)->toDistR
let algebraicPower = (~env, dist1, dist2) =>
C.algebraicPower(dist1, dist2)->run(~env)->toDistR
let algebraicPower = (~env, dist1, dist2) => C.algebraicPower(dist1, dist2)->run(~env)->toDistR
let pointwiseAdd = (~env, dist1, dist2) => C.pointwiseAdd(dist1, dist2)->run(~env)->toDistR
let pointwiseMultiply = (~env, dist1, dist2) =>
C.pointwiseMultiply(dist1, dist2)->run(~env)->toDistR
@ -223,6 +226,5 @@ module Constructors = {
C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR
let pointwiseLogarithm = (~env, dist1, dist2) =>
C.pointwiseLogarithm(dist1, dist2)->run(~env)->toDistR
let pointwisePower = (~env, dist1, dist2) =>
C.pointwisePower(dist1, dist2)->run(~env)->toDistR
let pointwisePower = (~env, dist1, dist2) => C.pointwisePower(dist1, dist2)->run(~env)->toDistR
}

View File

@ -5,14 +5,16 @@ type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error>
let sampleN = (t: t, n) =>
switch t {
| PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
| Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
| SampleSet(r) => Ok(SampleSetDist.sampleN(r, n))
| PointSet(r) => PointSetDist.sampleNRendered(n, r)
| Symbolic(r) => SymbolicDist.T.sampleN(n, r)
| SampleSet(r) => SampleSetDist.sampleN(r, n)
}
let toSampleSetDist = (t: t, n) =>
sampleN(t, n)->E.R.bind(SampleSetDist.make)->GenericDist_Types.Error.resultStringToResultError
SampleSetDist.make(sampleN(t, n))->GenericDist_Types.Error.resultStringToResultError
let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
@ -72,7 +74,6 @@ let toPointSet = (
pointSetDistLength: xyPointLength,
kernelWidth: None,
},
(),
)->GenericDist_Types.Error.resultStringToResultError
}
}
@ -162,14 +163,12 @@ module AlgebraicCombination = {
t1: t,
t2: t,
) => {
let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation)
E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R.bind(((a, b)) => {
SampleSetDist.map2(
~fn=arithmeticOperation,
~t1=a,
~t2=b,
)->GenericDist_Types.Error.resultStringToResultError
let fn = Operation.Algebraic.toFn(arithmeticOperation)
E.R.merge(toSampleSet(t1), toSampleSet(t2))
->E.R.bind(((t1, t2)) => {
SampleSetDist.map2(~fn, ~t1, ~t2)->GenericDist_Types.Error.resultStringToResultError
})
->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
}
//I'm (Ozzie) really just guessing here, very little idea what's best
@ -200,15 +199,7 @@ module AlgebraicCombination = {
| Some(Error(e)) => Error(Other(e))
| None =>
switch chooseConvolutionOrMonteCarlo(t1, t2) {
| #CalculateWithMonteCarlo => {
let sampleSetDist: result<SampleSetDist.t, error> = runMonteCarlo(
toSampleSetFn,
arithmeticOperation,
t1,
t2,
)
sampleSetDist->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
}
| #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
| #CalculateWithConvolution =>
runConvolution(
toPointSetFn,
@ -274,7 +265,7 @@ let mixture = (
~pointwiseAddFn: pointwiseAddFn,
) => {
if E.A.length(values) == 0 {
Error(GenericDist_Types.Other("mixture must have at least 1 element"))
Error(GenericDist_Types.Other("Mixture error: mixture must have at least 1 element"))
} else {
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
let properlyWeightedValues =

View File

@ -5,7 +5,8 @@ type toSampleSetFn = t => result<SampleSetDist.t, error>
type scaleMultiplyFn = (t, float) => result<t, error>
type pointwiseAddFn = (t, t) => result<t, error>
let sampleN: (t, int) => result<array<float>, error>
let sampleN: (t, int) => array<float>
let toSampleSetDist: (t, int) => Belt.Result.t<QuriSquiggleLang.SampleSetDist.t, error>
let fromFloat: float => t

View File

@ -1,4 +1,5 @@
module T: {
@genType
type t
let make: array<float> => result<t, string>
let get: t => array<float>
@ -18,7 +19,7 @@ include T
let length = (t: t) => get(t) |> E.A.length
// TODO: Refactor to get error in the toPointSetDist function, instead of adding at very end.
let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()): result<
let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs): result<
PointSetTypes.pointSetDist,
string,
> =>

View File

@ -225,7 +225,7 @@ module SamplingDistribution = {
let pointSetDist =
sampleSetDist
-> E.R.bind(r =>
SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()));
SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r));
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
})
}