Merge pull request #1002 from quantified-uncertainty/sampleset-mixture
Sampleset mixture
This commit is contained in:
commit
e6d543daef
219
.github/workflows/ci.yml
vendored
219
.github/workflows/ci.yml
vendored
|
@ -49,26 +49,26 @@ jobs:
|
||||||
with:
|
with:
|
||||||
paths: '["packages/cli/**"]'
|
paths: '["packages/cli/**"]'
|
||||||
|
|
||||||
# lang-lint:
|
# lang-lint:
|
||||||
# name: Language lint
|
# name: Language lint
|
||||||
# runs-on: ubuntu-latest
|
# runs-on: ubuntu-latest
|
||||||
# needs: pre_check
|
# needs: pre_check
|
||||||
# if: ${{ needs.pre_check.outputs.should_skip_lang != 'true' }}
|
# if: ${{ needs.pre_check.outputs.should_skip_lang != 'true' }}
|
||||||
# defaults:
|
# defaults:
|
||||||
# run:
|
# run:
|
||||||
# shell: bash
|
# shell: bash
|
||||||
# working-directory: packages/squiggle-lang
|
# working-directory: packages/squiggle-lang
|
||||||
# steps:
|
# steps:
|
||||||
# - uses: actions/checkout@v3
|
# - uses: actions/checkout@v3
|
||||||
# - name: Install Dependencies
|
# - name: Install Dependencies
|
||||||
# run: cd ../../ && yarn
|
# run: cd ../../ && yarn
|
||||||
# - name: Check rescript lint
|
# - name: Check rescript lint
|
||||||
# run: yarn lint:rescript
|
# run: yarn lint:rescript
|
||||||
# - name: Check javascript, typescript, and markdown lint
|
# - name: Check javascript, typescript, and markdown lint
|
||||||
# uses: creyD/prettier_action@v4.2
|
# uses: creyD/prettier_action@v4.2
|
||||||
# with:
|
# with:
|
||||||
# dry: true
|
# dry: true
|
||||||
# prettier_options: --check packages/squiggle-lang
|
# prettier_options: --check packages/squiggle-lang
|
||||||
|
|
||||||
lang-build-test-bundle:
|
lang-build-test-bundle:
|
||||||
name: Language build, test, and bundle
|
name: Language build, test, and bundle
|
||||||
|
@ -98,96 +98,96 @@ jobs:
|
||||||
- name: Upload typescript coverage report
|
- name: Upload typescript coverage report
|
||||||
run: yarn coverage:ts:ci
|
run: yarn coverage:ts:ci
|
||||||
|
|
||||||
# components-lint:
|
# components-lint:
|
||||||
# name: Components lint
|
# name: Components lint
|
||||||
# runs-on: ubuntu-latest
|
# runs-on: ubuntu-latest
|
||||||
# needs: pre_check
|
# needs: pre_check
|
||||||
# if: ${{ needs.pre_check.outputs.should_skip_components != 'true' }}
|
# if: ${{ needs.pre_check.outputs.should_skip_components != 'true' }}
|
||||||
# defaults:
|
# defaults:
|
||||||
# run:
|
# run:
|
||||||
# shell: bash
|
# shell: bash
|
||||||
# working-directory: packages/components
|
# working-directory: packages/components
|
||||||
# steps:
|
# steps:
|
||||||
# - uses: actions/checkout@v3
|
# - uses: actions/checkout@v3
|
||||||
# - name: Check javascript, typescript, and markdown lint
|
# - name: Check javascript, typescript, and markdown lint
|
||||||
# uses: creyD/prettier_action@v4.2
|
# uses: creyD/prettier_action@v4.2
|
||||||
# with:
|
# with:
|
||||||
# dry: true
|
# dry: true
|
||||||
# prettier_options: --check packages/components --ignore-path packages/components/.prettierignore
|
# prettier_options: --check packages/components --ignore-path packages/components/.prettierignore
|
||||||
#
|
#
|
||||||
# components-bundle-build:
|
# components-bundle-build:
|
||||||
# name: Components bundle and build
|
# name: Components bundle and build
|
||||||
# runs-on: ubuntu-latest
|
# runs-on: ubuntu-latest
|
||||||
# needs: pre_check
|
# needs: pre_check
|
||||||
# if: ${{ (needs.pre_check.outputs.should_skip_components != 'true') || (needs.pre_check.outputs.should_skip_lang != 'true') }}
|
# if: ${{ (needs.pre_check.outputs.should_skip_components != 'true') || (needs.pre_check.outputs.should_skip_lang != 'true') }}
|
||||||
# defaults:
|
# defaults:
|
||||||
# run:
|
# run:
|
||||||
# shell: bash
|
# shell: bash
|
||||||
# working-directory: packages/components
|
# working-directory: packages/components
|
||||||
# steps:
|
# steps:
|
||||||
# - uses: actions/checkout@v3
|
# - uses: actions/checkout@v3
|
||||||
# - name: Install dependencies from monorepo level
|
# - name: Install dependencies from monorepo level
|
||||||
# run: cd ../../ && yarn
|
# run: cd ../../ && yarn
|
||||||
# - name: Build rescript codebase in squiggle-lang
|
# - name: Build rescript codebase in squiggle-lang
|
||||||
# run: cd ../squiggle-lang && yarn build
|
# run: cd ../squiggle-lang && yarn build
|
||||||
# - name: Run webpack
|
# - name: Run webpack
|
||||||
# run: yarn bundle
|
# run: yarn bundle
|
||||||
# - name: Build storybook
|
# - name: Build storybook
|
||||||
# run: yarn build
|
# run: yarn build
|
||||||
|
|
||||||
# website-lint:
|
# website-lint:
|
||||||
# name: Website lint
|
# name: Website lint
|
||||||
# runs-on: ubuntu-latest
|
# runs-on: ubuntu-latest
|
||||||
# needs: pre_check
|
# needs: pre_check
|
||||||
# if: ${{ needs.pre_check.outputs.should_skip_website != 'true' }}
|
# if: ${{ needs.pre_check.outputs.should_skip_website != 'true' }}
|
||||||
# defaults:
|
# defaults:
|
||||||
# run:
|
# run:
|
||||||
# shell: bash
|
# shell: bash
|
||||||
# working-directory: packages/website
|
# working-directory: packages/website
|
||||||
# steps:
|
# steps:
|
||||||
# - uses: actions/checkout@v3
|
# - uses: actions/checkout@v3
|
||||||
# - name: Check javascript, typescript, and markdown lint
|
# - name: Check javascript, typescript, and markdown lint
|
||||||
# uses: creyD/prettier_action@v4.2
|
# uses: creyD/prettier_action@v4.2
|
||||||
# with:
|
# with:
|
||||||
# dry: true
|
# dry: true
|
||||||
# prettier_options: --check packages/website
|
# prettier_options: --check packages/website
|
||||||
#
|
#
|
||||||
# website-build:
|
# website-build:
|
||||||
# name: Website build
|
# name: Website build
|
||||||
# runs-on: ubuntu-latest
|
# runs-on: ubuntu-latest
|
||||||
# needs: pre_check
|
# needs: pre_check
|
||||||
# if: ${{ (needs.pre_check.outputs.should_skip_website != 'true') || (needs.pre_check.outputs.should_skip_lang != 'true') || (needs.pre_check.outputs.should_skip_components != 'true') }}
|
# if: ${{ (needs.pre_check.outputs.should_skip_website != 'true') || (needs.pre_check.outputs.should_skip_lang != 'true') || (needs.pre_check.outputs.should_skip_components != 'true') }}
|
||||||
# defaults:
|
# defaults:
|
||||||
# run:
|
# run:
|
||||||
# shell: bash
|
# shell: bash
|
||||||
# working-directory: packages/website
|
# working-directory: packages/website
|
||||||
# steps:
|
# steps:
|
||||||
# - uses: actions/checkout@v3
|
# - uses: actions/checkout@v3
|
||||||
# - name: Install dependencies from monorepo level
|
# - name: Install dependencies from monorepo level
|
||||||
# run: cd ../../ && yarn
|
# run: cd ../../ && yarn
|
||||||
# - name: Build rescript in squiggle-lang
|
# - name: Build rescript in squiggle-lang
|
||||||
# run: cd ../squiggle-lang && yarn build
|
# run: cd ../squiggle-lang && yarn build
|
||||||
# - name: Build components
|
# - name: Build components
|
||||||
# run: cd ../components && yarn build
|
# run: cd ../components && yarn build
|
||||||
# - name: Build website assets
|
# - name: Build website assets
|
||||||
# run: yarn build
|
# run: yarn build
|
||||||
#
|
#
|
||||||
# vscode-ext-lint:
|
# vscode-ext-lint:
|
||||||
# name: VS Code extension lint
|
# name: VS Code extension lint
|
||||||
# runs-on: ubuntu-latest
|
# runs-on: ubuntu-latest
|
||||||
# needs: pre_check
|
# needs: pre_check
|
||||||
# if: ${{ needs.pre_check.outputs.should_skip_vscodeext != 'true' }}
|
# if: ${{ needs.pre_check.outputs.should_skip_vscodeext != 'true' }}
|
||||||
# defaults:
|
# defaults:
|
||||||
# run:
|
# run:
|
||||||
# shell: bash
|
# shell: bash
|
||||||
# working-directory: packages/vscode-ext
|
# working-directory: packages/vscode-ext
|
||||||
# steps:
|
# steps:
|
||||||
# - uses: actions/checkout@v3
|
# - uses: actions/checkout@v3
|
||||||
# - name: Check javascript, typescript, and markdown lint
|
# - name: Check javascript, typescript, and markdown lint
|
||||||
# uses: creyD/prettier_action@v4.2
|
# uses: creyD/prettier_action@v4.2
|
||||||
# with:
|
# with:
|
||||||
# dry: true
|
# dry: true
|
||||||
# prettier_options: --check packages/vscode-ext
|
# prettier_options: --check packages/vscode-ext
|
||||||
|
|
||||||
vscode-ext-build:
|
vscode-ext-build:
|
||||||
name: VS Code extension build
|
name: VS Code extension build
|
||||||
|
@ -204,7 +204,6 @@ jobs:
|
||||||
run: cd ../../ && yarn
|
run: cd ../../ && yarn
|
||||||
- name: Build
|
- name: Build
|
||||||
run: yarn compile
|
run: yarn compile
|
||||||
|
|
||||||
# cli-lint:
|
# cli-lint:
|
||||||
# name: CLI lint
|
# name: CLI lint
|
||||||
# runs-on: ubuntu-latest
|
# runs-on: ubuntu-latest
|
||||||
|
|
20
packages/squiggle-lang/__tests__/Stdlib_test.res
Normal file
20
packages/squiggle-lang/__tests__/Stdlib_test.res
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
open Jest
|
||||||
|
open Expect
|
||||||
|
|
||||||
|
let makeTest = (~only=false, str, item1, item2) =>
|
||||||
|
only
|
||||||
|
? Only.test(str, () => expect(item1)->toEqual(item2))
|
||||||
|
: test(str, () => expect(item1)->toEqual(item2))
|
||||||
|
|
||||||
|
describe("Stdlib", () => {
|
||||||
|
makeTest(
|
||||||
|
"Length of Random.sample",
|
||||||
|
Stdlib.Random.sample([1.0, 2.0], {probs: [0.5, 0.5], size: 10})->E.A.length,
|
||||||
|
10,
|
||||||
|
)
|
||||||
|
makeTest(
|
||||||
|
"Random.sample returns elements from input array (will fail with very slim probability)",
|
||||||
|
Stdlib.Random.sample([1.0, 2.0], {probs: [0.5, 0.5], size: 10})->E.A.uniq->E.A.Floats.sort,
|
||||||
|
[1.0, 2.0],
|
||||||
|
)
|
||||||
|
})
|
|
@ -18,6 +18,7 @@
|
||||||
"benchmark": "ts-node benchmark/conversion_tests.ts",
|
"benchmark": "ts-node benchmark/conversion_tests.ts",
|
||||||
"test": "jest",
|
"test": "jest",
|
||||||
"test:ts": "jest __tests__/TS/",
|
"test:ts": "jest __tests__/TS/",
|
||||||
|
"test:stdlib": "jest __tests__/Stdlib_test.bs.js",
|
||||||
"test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
|
"test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
|
||||||
"test:watch": "jest --watchAll",
|
"test:watch": "jest --watchAll",
|
||||||
"test:fnRegistry": "jest __tests__/SquiggleLibrary/SquiggleLibrary_FunctionRegistryLibrary_test.bs.js",
|
"test:fnRegistry": "jest __tests__/SquiggleLibrary/SquiggleLibrary_FunctionRegistryLibrary_test.bs.js",
|
||||||
|
|
|
@ -216,7 +216,7 @@ let rec run = (~env: env, functionCallInfo: functionCallInfo): outputType => {
|
||||||
| FromFloat(subFnName, x) => reCall(~functionCallInfo=FromFloat(subFnName, x), ())
|
| FromFloat(subFnName, x) => reCall(~functionCallInfo=FromFloat(subFnName, x), ())
|
||||||
| Mixture(dists) =>
|
| Mixture(dists) =>
|
||||||
dists
|
dists
|
||||||
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd)
|
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd, ~env)
|
||||||
->E.R2.fmap(r => Dist(r))
|
->E.R2.fmap(r => Dist(r))
|
||||||
->OutputLocal.fromResult
|
->OutputLocal.fromResult
|
||||||
| FromSamples(xs) =>
|
| FromSamples(xs) =>
|
||||||
|
|
|
@ -499,15 +499,30 @@ let pointwiseCombinationFloat = (
|
||||||
m->E.R2.fmap(r => DistributionTypes.PointSet(r))
|
m->E.R2.fmap(r => DistributionTypes.PointSet(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
//Note: The result should always cumulatively sum to 1. This would be good to test.
|
//TODO: The result should always cumulatively sum to 1. This would be good to test.
|
||||||
//Note: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this.
|
//TODO: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this.
|
||||||
let mixture = (
|
let mixture = (
|
||||||
values: array<(t, float)>,
|
values: array<(t, float)>,
|
||||||
~scaleMultiplyFn: scaleMultiplyFn,
|
~scaleMultiplyFn: scaleMultiplyFn,
|
||||||
~pointwiseAddFn: pointwiseAddFn,
|
~pointwiseAddFn: pointwiseAddFn,
|
||||||
|
~env: env,
|
||||||
) => {
|
) => {
|
||||||
if E.A.length(values) == 0 {
|
let allValuesAreSampleSet = v => E.A.all(((t, _)) => isSampleSetSet(t), v)
|
||||||
|
|
||||||
|
if E.A.isEmpty(values) {
|
||||||
Error(DistributionTypes.OtherError("Mixture error: mixture must have at least 1 element"))
|
Error(DistributionTypes.OtherError("Mixture error: mixture must have at least 1 element"))
|
||||||
|
} else if allValuesAreSampleSet(values) {
|
||||||
|
let withSampleSetValues = values->E.A2.fmap(((value, weight)) =>
|
||||||
|
switch value {
|
||||||
|
| SampleSet(sampleSet) => Ok((sampleSet, weight))
|
||||||
|
| _ => Error("Unreachable")
|
||||||
|
}->E.R2.toExn("Mixture coding error: SampleSet expected. This should be inaccessible.")
|
||||||
|
)
|
||||||
|
let sampleSetMixture = SampleSetDist.mixture(withSampleSetValues, env.sampleCount)
|
||||||
|
switch sampleSetMixture {
|
||||||
|
| Ok(sampleSet) => Ok(DistributionTypes.SampleSet(sampleSet))
|
||||||
|
| Error(err) => Error(DistributionTypes.Error.sampleErrorToDistErr(err))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
||||||
let properlyWeightedValues =
|
let properlyWeightedValues =
|
||||||
|
|
|
@ -81,6 +81,7 @@ let mixture: (
|
||||||
array<(t, float)>,
|
array<(t, float)>,
|
||||||
~scaleMultiplyFn: scaleMultiplyFn,
|
~scaleMultiplyFn: scaleMultiplyFn,
|
||||||
~pointwiseAddFn: pointwiseAddFn,
|
~pointwiseAddFn: pointwiseAddFn,
|
||||||
|
~env: env,
|
||||||
) => result<t, error>
|
) => result<t, error>
|
||||||
|
|
||||||
let isSymbolic: t => bool
|
let isSymbolic: t => bool
|
||||||
|
|
|
@ -224,3 +224,8 @@ module T = Dist({
|
||||||
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
|
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
let sampleN = (t: t, n): array<float> => {
|
||||||
|
let normalized = t->T.normalize->getShape
|
||||||
|
Stdlib.Random.sample(normalized.xs, {probs: normalized.ys, size: n})
|
||||||
|
}
|
||||||
|
|
|
@ -257,3 +257,7 @@ let toSparkline = (t: t, bucketCount): result<string, PointSetTypes.sparklineErr
|
||||||
->E.O2.fmap(Continuous.downsampleEquallyOverX(bucketCount))
|
->E.O2.fmap(Continuous.downsampleEquallyOverX(bucketCount))
|
||||||
->E.O2.toResult(PointSetTypes.CannotSparklineDiscrete)
|
->E.O2.toResult(PointSetTypes.CannotSparklineDiscrete)
|
||||||
->E.R2.fmap(r => Continuous.getShape(r).ys->Sparklines.create())
|
->E.R2.fmap(r => Continuous.getShape(r).ys->Sparklines.create())
|
||||||
|
|
||||||
|
let makeDiscrete = (d): t => Discrete(d)
|
||||||
|
let makeContinuous = (d): t => Continuous(d)
|
||||||
|
let makeMixed = (d): t => Mixed(d)
|
||||||
|
|
|
@ -132,6 +132,25 @@ let stdev = t => T.get(t)->E.A.Floats.stdev
|
||||||
let variance = t => T.get(t)->E.A.Floats.variance
|
let variance = t => T.get(t)->E.A.Floats.variance
|
||||||
let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f)
|
let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f)
|
||||||
|
|
||||||
|
let mixture = (values: array<(t, float)>, intendedLength: int) => {
|
||||||
|
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
|
||||||
|
let discreteSamples =
|
||||||
|
values
|
||||||
|
->Belt.Array.mapWithIndex((i, (_, weight)) => (E.I.toFloat(i), weight /. totalWeight))
|
||||||
|
->XYShape.T.fromZippedArray
|
||||||
|
->Discrete.make
|
||||||
|
->Discrete.sampleN(intendedLength)
|
||||||
|
let dists = values->E.A2.fmap(E.Tuple2.first)->E.A2.fmap(T.get)
|
||||||
|
let samples =
|
||||||
|
discreteSamples
|
||||||
|
->Belt.Array.mapWithIndex((index, distIndexToChoose) => {
|
||||||
|
let chosenDist = E.A.get(dists, E.Float.toInt(distIndexToChoose))
|
||||||
|
chosenDist->E.O.bind(E.A.get(_, index))
|
||||||
|
})
|
||||||
|
->E.A.O.openIfAllSome
|
||||||
|
samples->E.O2.toExn("Mixture unreachable error")->T.make
|
||||||
|
}
|
||||||
|
|
||||||
let truncateLeft = (t, f) => T.get(t)->E.A2.filter(x => x >= f)->T.make
|
let truncateLeft = (t, f) => T.get(t)->E.A2.filter(x => x >= f)->T.make
|
||||||
let truncateRight = (t, f) => T.get(t)->E.A2.filter(x => x <= f)->T.make
|
let truncateRight = (t, f) => T.get(t)->E.A2.filter(x => x <= f)->T.make
|
||||||
|
|
||||||
|
|
|
@ -220,6 +220,7 @@ module I = {
|
||||||
let increment = n => n + 1
|
let increment = n => n + 1
|
||||||
let decrement = n => n - 1
|
let decrement = n => n - 1
|
||||||
let toString = Js.Int.toString
|
let toString = Js.Int.toString
|
||||||
|
let toFloat = Js.Int.toFloat
|
||||||
}
|
}
|
||||||
|
|
||||||
exception Assertion(string)
|
exception Assertion(string)
|
||||||
|
|
|
@ -38,3 +38,12 @@ module Logistic = {
|
||||||
@module external variance: (float, float) => float = "@stdlib/stats/base/dists/logistic/variance"
|
@module external variance: (float, float) => float = "@stdlib/stats/base/dists/logistic/variance"
|
||||||
let variance = variance
|
let variance = variance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
module Random = {
|
||||||
|
type sampleArgs = {
|
||||||
|
probs: array<float>,
|
||||||
|
size: int,
|
||||||
|
}
|
||||||
|
@module external sample: (array<float>, sampleArgs) => array<float> = "@stdlib/random/sample"
|
||||||
|
let sample = sample
|
||||||
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ module.exports = {
|
||||||
},
|
},
|
||||||
resolve: {
|
resolve: {
|
||||||
extensions: [".tsx", ".ts", ".js"],
|
extensions: [".tsx", ".ts", ".js"],
|
||||||
|
fallback: { buffer: ["@stdlib/buffer"] },
|
||||||
},
|
},
|
||||||
output: {
|
output: {
|
||||||
filename: "bundle.js",
|
filename: "bundle.js",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user