Added genType to SampleSetDist to make pass tests, other minor fixes
This commit is contained in:
parent
9ad73fe69b
commit
4338f482ef
|
@ -128,7 +128,10 @@ let rec run = (~env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
->E.R2.fmap(r => Dist(r))
|
->E.R2.fmap(r => Dist(r))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| ToDist(ToSampleSet(n)) =>
|
| ToDist(ToSampleSet(n)) =>
|
||||||
dist->GenericDist.toSampleSetDist(n)->E.R2.fmap(r => Dist(SampleSet(r)))->OutputLocal.fromResult
|
dist
|
||||||
|
->GenericDist.toSampleSetDist(n)
|
||||||
|
->E.R2.fmap(r => Dist(SampleSet(r)))
|
||||||
|
->OutputLocal.fromResult
|
||||||
| ToDist(ToPointSet) =>
|
| ToDist(ToPointSet) =>
|
||||||
dist
|
dist
|
||||||
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())
|
->GenericDist.toPointSet(~xyPointLength, ~sampleCount, ())
|
||||||
|
@ -204,7 +207,8 @@ module Constructors = {
|
||||||
C.truncate(dist, leftCutoff, rightCutoff)->run(~env)->toDistR
|
C.truncate(dist, leftCutoff, rightCutoff)->run(~env)->toDistR
|
||||||
let inspect = (~env, dist) => C.inspect(dist)->run(~env)->toDistR
|
let inspect = (~env, dist) => C.inspect(dist)->run(~env)->toDistR
|
||||||
let toString = (~env, dist) => C.toString(dist)->run(~env)->toStringR
|
let toString = (~env, dist) => C.toString(dist)->run(~env)->toStringR
|
||||||
let toSparkline = (~env, dist, bucketCount) => C.toSparkline(dist, bucketCount)->run(~env)->toStringR
|
let toSparkline = (~env, dist, bucketCount) =>
|
||||||
|
C.toSparkline(dist, bucketCount)->run(~env)->toStringR
|
||||||
let algebraicAdd = (~env, dist1, dist2) => C.algebraicAdd(dist1, dist2)->run(~env)->toDistR
|
let algebraicAdd = (~env, dist1, dist2) => C.algebraicAdd(dist1, dist2)->run(~env)->toDistR
|
||||||
let algebraicMultiply = (~env, dist1, dist2) =>
|
let algebraicMultiply = (~env, dist1, dist2) =>
|
||||||
C.algebraicMultiply(dist1, dist2)->run(~env)->toDistR
|
C.algebraicMultiply(dist1, dist2)->run(~env)->toDistR
|
||||||
|
@ -213,8 +217,7 @@ module Constructors = {
|
||||||
C.algebraicSubtract(dist1, dist2)->run(~env)->toDistR
|
C.algebraicSubtract(dist1, dist2)->run(~env)->toDistR
|
||||||
let algebraicLogarithm = (~env, dist1, dist2) =>
|
let algebraicLogarithm = (~env, dist1, dist2) =>
|
||||||
C.algebraicLogarithm(dist1, dist2)->run(~env)->toDistR
|
C.algebraicLogarithm(dist1, dist2)->run(~env)->toDistR
|
||||||
let algebraicPower = (~env, dist1, dist2) =>
|
let algebraicPower = (~env, dist1, dist2) => C.algebraicPower(dist1, dist2)->run(~env)->toDistR
|
||||||
C.algebraicPower(dist1, dist2)->run(~env)->toDistR
|
|
||||||
let pointwiseAdd = (~env, dist1, dist2) => C.pointwiseAdd(dist1, dist2)->run(~env)->toDistR
|
let pointwiseAdd = (~env, dist1, dist2) => C.pointwiseAdd(dist1, dist2)->run(~env)->toDistR
|
||||||
let pointwiseMultiply = (~env, dist1, dist2) =>
|
let pointwiseMultiply = (~env, dist1, dist2) =>
|
||||||
C.pointwiseMultiply(dist1, dist2)->run(~env)->toDistR
|
C.pointwiseMultiply(dist1, dist2)->run(~env)->toDistR
|
||||||
|
@ -223,6 +226,5 @@ module Constructors = {
|
||||||
C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR
|
C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR
|
||||||
let pointwiseLogarithm = (~env, dist1, dist2) =>
|
let pointwiseLogarithm = (~env, dist1, dist2) =>
|
||||||
C.pointwiseLogarithm(dist1, dist2)->run(~env)->toDistR
|
C.pointwiseLogarithm(dist1, dist2)->run(~env)->toDistR
|
||||||
let pointwisePower = (~env, dist1, dist2) =>
|
let pointwisePower = (~env, dist1, dist2) => C.pointwisePower(dist1, dist2)->run(~env)->toDistR
|
||||||
C.pointwisePower(dist1, dist2)->run(~env)->toDistR
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,14 +5,16 @@ type toPointSetFn = t => result<PointSetTypes.pointSetDist, error>
|
||||||
type toSampleSetFn = t => result<SampleSetDist.t, error>
|
type toSampleSetFn = t => result<SampleSetDist.t, error>
|
||||||
type scaleMultiplyFn = (t, float) => result<t, error>
|
type scaleMultiplyFn = (t, float) => result<t, error>
|
||||||
type pointwiseAddFn = (t, t) => result<t, error>
|
type pointwiseAddFn = (t, t) => result<t, error>
|
||||||
|
|
||||||
let sampleN = (t: t, n) =>
|
let sampleN = (t: t, n) =>
|
||||||
switch t {
|
switch t {
|
||||||
| PointSet(r) => Ok(PointSetDist.sampleNRendered(n, r))
|
| PointSet(r) => PointSetDist.sampleNRendered(n, r)
|
||||||
| Symbolic(r) => Ok(SymbolicDist.T.sampleN(n, r))
|
| Symbolic(r) => SymbolicDist.T.sampleN(n, r)
|
||||||
| SampleSet(r) => Ok(SampleSetDist.sampleN(r, n))
|
| SampleSet(r) => SampleSetDist.sampleN(r, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
let toSampleSetDist = (t: t, n) =>
|
let toSampleSetDist = (t: t, n) =>
|
||||||
sampleN(t, n)->E.R.bind(SampleSetDist.make)->GenericDist_Types.Error.resultStringToResultError
|
SampleSetDist.make(sampleN(t, n))->GenericDist_Types.Error.resultStringToResultError
|
||||||
|
|
||||||
let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
|
let fromFloat = (f: float): t => Symbolic(SymbolicDist.Float.make(f))
|
||||||
|
|
||||||
|
@ -72,7 +74,6 @@ let toPointSet = (
|
||||||
pointSetDistLength: xyPointLength,
|
pointSetDistLength: xyPointLength,
|
||||||
kernelWidth: None,
|
kernelWidth: None,
|
||||||
},
|
},
|
||||||
(),
|
|
||||||
)->GenericDist_Types.Error.resultStringToResultError
|
)->GenericDist_Types.Error.resultStringToResultError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -162,14 +163,12 @@ module AlgebraicCombination = {
|
||||||
t1: t,
|
t1: t,
|
||||||
t2: t,
|
t2: t,
|
||||||
) => {
|
) => {
|
||||||
let arithmeticOperation = Operation.Algebraic.toFn(arithmeticOperation)
|
let fn = Operation.Algebraic.toFn(arithmeticOperation)
|
||||||
E.R.merge(toSampleSet(t1), toSampleSet(t2))->E.R.bind(((a, b)) => {
|
E.R.merge(toSampleSet(t1), toSampleSet(t2))
|
||||||
SampleSetDist.map2(
|
->E.R.bind(((t1, t2)) => {
|
||||||
~fn=arithmeticOperation,
|
SampleSetDist.map2(~fn, ~t1, ~t2)->GenericDist_Types.Error.resultStringToResultError
|
||||||
~t1=a,
|
|
||||||
~t2=b,
|
|
||||||
)->GenericDist_Types.Error.resultStringToResultError
|
|
||||||
})
|
})
|
||||||
|
->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
//I'm (Ozzie) really just guessing here, very little idea what's best
|
//I'm (Ozzie) really just guessing here, very little idea what's best
|
||||||
|
@ -200,15 +199,7 @@ module AlgebraicCombination = {
|
||||||
| Some(Error(e)) => Error(Other(e))
|
| Some(Error(e)) => Error(Other(e))
|
||||||
| None =>
|
| None =>
|
||||||
switch chooseConvolutionOrMonteCarlo(t1, t2) {
|
switch chooseConvolutionOrMonteCarlo(t1, t2) {
|
||||||
| #CalculateWithMonteCarlo => {
|
| #CalculateWithMonteCarlo => runMonteCarlo(toSampleSetFn, arithmeticOperation, t1, t2)
|
||||||
let sampleSetDist: result<SampleSetDist.t, error> = runMonteCarlo(
|
|
||||||
toSampleSetFn,
|
|
||||||
arithmeticOperation,
|
|
||||||
t1,
|
|
||||||
t2,
|
|
||||||
)
|
|
||||||
sampleSetDist->E.R2.fmap(r => GenericDist_Types.SampleSet(r))
|
|
||||||
}
|
|
||||||
| #CalculateWithConvolution =>
|
| #CalculateWithConvolution =>
|
||||||
runConvolution(
|
runConvolution(
|
||||||
toPointSetFn,
|
toPointSetFn,
|
||||||
|
@ -274,7 +265,7 @@ let mixture = (
|
||||||
~pointwiseAddFn: pointwiseAddFn,
|
~pointwiseAddFn: pointwiseAddFn,
|
||||||
) => {
|
) => {
|
||||||
if E.A.length(values) == 0 {
|
if E.A.length(values) == 0 {
|
||||||
Error(GenericDist_Types.Other("mixture must have at least 1 element"))
|
Error(GenericDist_Types.Other("Mixture error: mixture must have at least 1 element"))
|
||||||
} else {
|
} else {
|
||||||
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
||||||
let properlyWeightedValues =
|
let properlyWeightedValues =
|
||||||
|
|
|
@ -5,7 +5,8 @@ type toSampleSetFn = t => result<SampleSetDist.t, error>
|
||||||
type scaleMultiplyFn = (t, float) => result<t, error>
|
type scaleMultiplyFn = (t, float) => result<t, error>
|
||||||
type pointwiseAddFn = (t, t) => result<t, error>
|
type pointwiseAddFn = (t, t) => result<t, error>
|
||||||
|
|
||||||
let sampleN: (t, int) => result<array<float>, error>
|
let sampleN: (t, int) => array<float>
|
||||||
|
|
||||||
let toSampleSetDist: (t, int) => Belt.Result.t<QuriSquiggleLang.SampleSetDist.t, error>
|
let toSampleSetDist: (t, int) => Belt.Result.t<QuriSquiggleLang.SampleSetDist.t, error>
|
||||||
|
|
||||||
let fromFloat: float => t
|
let fromFloat: float => t
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
module T: {
|
module T: {
|
||||||
|
@genType
|
||||||
type t
|
type t
|
||||||
let make: array<float> => result<t, string>
|
let make: array<float> => result<t, string>
|
||||||
let get: t => array<float>
|
let get: t => array<float>
|
||||||
|
@ -18,7 +19,7 @@ include T
|
||||||
let length = (t: t) => get(t) |> E.A.length
|
let length = (t: t) => get(t) |> E.A.length
|
||||||
|
|
||||||
// TODO: Refactor to get error in the toPointSetDist function, instead of adding at very end.
|
// TODO: Refactor to get error in the toPointSetDist function, instead of adding at very end.
|
||||||
let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs, ()): result<
|
let toPointSetDist = (~samples: t, ~samplingInputs: SamplingInputs.samplingInputs): result<
|
||||||
PointSetTypes.pointSetDist,
|
PointSetTypes.pointSetDist,
|
||||||
string,
|
string,
|
||||||
> =>
|
> =>
|
||||||
|
|
|
@ -225,7 +225,7 @@ module SamplingDistribution = {
|
||||||
let pointSetDist =
|
let pointSetDist =
|
||||||
sampleSetDist
|
sampleSetDist
|
||||||
-> E.R.bind(r =>
|
-> E.R.bind(r =>
|
||||||
SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r, ()));
|
SampleSetDist.toPointSetDist(~samplingInputs=evaluationParams.samplingInputs, ~samples=r));
|
||||||
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
|
pointSetDist |> E.R.fmap(r => #Normalize(#RenderedDist(r)))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user