Merge pull request #1002 from quantified-uncertainty/sampleset-mixture

Sampleset mixture
This commit is contained in:
Quinn 2022-09-01 03:00:17 -04:00 committed by GitHub
commit e6d543daef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 189 additions and 114 deletions

View File

@ -49,26 +49,26 @@ jobs:
with: with:
paths: '["packages/cli/**"]' paths: '["packages/cli/**"]'
# lang-lint: # lang-lint:
# name: Language lint # name: Language lint
# runs-on: ubuntu-latest # runs-on: ubuntu-latest
# needs: pre_check # needs: pre_check
# if: ${{ needs.pre_check.outputs.should_skip_lang != 'true' }} # if: ${{ needs.pre_check.outputs.should_skip_lang != 'true' }}
# defaults: # defaults:
# run: # run:
# shell: bash # shell: bash
# working-directory: packages/squiggle-lang # working-directory: packages/squiggle-lang
# steps: # steps:
# - uses: actions/checkout@v3 # - uses: actions/checkout@v3
# - name: Install Dependencies # - name: Install Dependencies
# run: cd ../../ && yarn # run: cd ../../ && yarn
# - name: Check rescript lint # - name: Check rescript lint
# run: yarn lint:rescript # run: yarn lint:rescript
# - name: Check javascript, typescript, and markdown lint # - name: Check javascript, typescript, and markdown lint
# uses: creyD/prettier_action@v4.2 # uses: creyD/prettier_action@v4.2
# with: # with:
# dry: true # dry: true
# prettier_options: --check packages/squiggle-lang # prettier_options: --check packages/squiggle-lang
lang-build-test-bundle: lang-build-test-bundle:
name: Language build, test, and bundle name: Language build, test, and bundle
@ -98,96 +98,96 @@ jobs:
- name: Upload typescript coverage report - name: Upload typescript coverage report
run: yarn coverage:ts:ci run: yarn coverage:ts:ci
# components-lint: # components-lint:
# name: Components lint # name: Components lint
# runs-on: ubuntu-latest # runs-on: ubuntu-latest
# needs: pre_check # needs: pre_check
# if: ${{ needs.pre_check.outputs.should_skip_components != 'true' }} # if: ${{ needs.pre_check.outputs.should_skip_components != 'true' }}
# defaults: # defaults:
# run: # run:
# shell: bash # shell: bash
# working-directory: packages/components # working-directory: packages/components
# steps: # steps:
# - uses: actions/checkout@v3 # - uses: actions/checkout@v3
# - name: Check javascript, typescript, and markdown lint # - name: Check javascript, typescript, and markdown lint
# uses: creyD/prettier_action@v4.2 # uses: creyD/prettier_action@v4.2
# with: # with:
# dry: true # dry: true
# prettier_options: --check packages/components --ignore-path packages/components/.prettierignore # prettier_options: --check packages/components --ignore-path packages/components/.prettierignore
# #
# components-bundle-build: # components-bundle-build:
# name: Components bundle and build # name: Components bundle and build
# runs-on: ubuntu-latest # runs-on: ubuntu-latest
# needs: pre_check # needs: pre_check
# if: ${{ (needs.pre_check.outputs.should_skip_components != 'true') || (needs.pre_check.outputs.should_skip_lang != 'true') }} # if: ${{ (needs.pre_check.outputs.should_skip_components != 'true') || (needs.pre_check.outputs.should_skip_lang != 'true') }}
# defaults: # defaults:
# run: # run:
# shell: bash # shell: bash
# working-directory: packages/components # working-directory: packages/components
# steps: # steps:
# - uses: actions/checkout@v3 # - uses: actions/checkout@v3
# - name: Install dependencies from monorepo level # - name: Install dependencies from monorepo level
# run: cd ../../ && yarn # run: cd ../../ && yarn
# - name: Build rescript codebase in squiggle-lang # - name: Build rescript codebase in squiggle-lang
# run: cd ../squiggle-lang && yarn build # run: cd ../squiggle-lang && yarn build
# - name: Run webpack # - name: Run webpack
# run: yarn bundle # run: yarn bundle
# - name: Build storybook # - name: Build storybook
# run: yarn build # run: yarn build
# website-lint: # website-lint:
# name: Website lint # name: Website lint
# runs-on: ubuntu-latest # runs-on: ubuntu-latest
# needs: pre_check # needs: pre_check
# if: ${{ needs.pre_check.outputs.should_skip_website != 'true' }} # if: ${{ needs.pre_check.outputs.should_skip_website != 'true' }}
# defaults: # defaults:
# run: # run:
# shell: bash # shell: bash
# working-directory: packages/website # working-directory: packages/website
# steps: # steps:
# - uses: actions/checkout@v3 # - uses: actions/checkout@v3
# - name: Check javascript, typescript, and markdown lint # - name: Check javascript, typescript, and markdown lint
# uses: creyD/prettier_action@v4.2 # uses: creyD/prettier_action@v4.2
# with: # with:
# dry: true # dry: true
# prettier_options: --check packages/website # prettier_options: --check packages/website
# #
# website-build: # website-build:
# name: Website build # name: Website build
# runs-on: ubuntu-latest # runs-on: ubuntu-latest
# needs: pre_check # needs: pre_check
# if: ${{ (needs.pre_check.outputs.should_skip_website != 'true') || (needs.pre_check.outputs.should_skip_lang != 'true') || (needs.pre_check.outputs.should_skip_components != 'true') }} # if: ${{ (needs.pre_check.outputs.should_skip_website != 'true') || (needs.pre_check.outputs.should_skip_lang != 'true') || (needs.pre_check.outputs.should_skip_components != 'true') }}
# defaults: # defaults:
# run: # run:
# shell: bash # shell: bash
# working-directory: packages/website # working-directory: packages/website
# steps: # steps:
# - uses: actions/checkout@v3 # - uses: actions/checkout@v3
# - name: Install dependencies from monorepo level # - name: Install dependencies from monorepo level
# run: cd ../../ && yarn # run: cd ../../ && yarn
# - name: Build rescript in squiggle-lang # - name: Build rescript in squiggle-lang
# run: cd ../squiggle-lang && yarn build # run: cd ../squiggle-lang && yarn build
# - name: Build components # - name: Build components
# run: cd ../components && yarn build # run: cd ../components && yarn build
# - name: Build website assets # - name: Build website assets
# run: yarn build # run: yarn build
# #
# vscode-ext-lint: # vscode-ext-lint:
# name: VS Code extension lint # name: VS Code extension lint
# runs-on: ubuntu-latest # runs-on: ubuntu-latest
# needs: pre_check # needs: pre_check
# if: ${{ needs.pre_check.outputs.should_skip_vscodeext != 'true' }} # if: ${{ needs.pre_check.outputs.should_skip_vscodeext != 'true' }}
# defaults: # defaults:
# run: # run:
# shell: bash # shell: bash
# working-directory: packages/vscode-ext # working-directory: packages/vscode-ext
# steps: # steps:
# - uses: actions/checkout@v3 # - uses: actions/checkout@v3
# - name: Check javascript, typescript, and markdown lint # - name: Check javascript, typescript, and markdown lint
# uses: creyD/prettier_action@v4.2 # uses: creyD/prettier_action@v4.2
# with: # with:
# dry: true # dry: true
# prettier_options: --check packages/vscode-ext # prettier_options: --check packages/vscode-ext
vscode-ext-build: vscode-ext-build:
name: VS Code extension build name: VS Code extension build
@ -204,7 +204,6 @@ jobs:
run: cd ../../ && yarn run: cd ../../ && yarn
- name: Build - name: Build
run: yarn compile run: yarn compile
# cli-lint: # cli-lint:
# name: CLI lint # name: CLI lint
# runs-on: ubuntu-latest # runs-on: ubuntu-latest

View File

@ -0,0 +1,20 @@
open Jest
open Expect
let makeTest = (~only=false, str, item1, item2) =>
only
? Only.test(str, () => expect(item1)->toEqual(item2))
: test(str, () => expect(item1)->toEqual(item2))
describe("Stdlib", () => {
makeTest(
"Length of Random.sample",
Stdlib.Random.sample([1.0, 2.0], {probs: [0.5, 0.5], size: 10})->E.A.length,
10,
)
makeTest(
"Random.sample returns elements from input array (will fail with very slim probability)",
Stdlib.Random.sample([1.0, 2.0], {probs: [0.5, 0.5], size: 10})->E.A.uniq->E.A.Floats.sort,
[1.0, 2.0],
)
})

View File

@ -18,6 +18,7 @@
"benchmark": "ts-node benchmark/conversion_tests.ts", "benchmark": "ts-node benchmark/conversion_tests.ts",
"test": "jest", "test": "jest",
"test:ts": "jest __tests__/TS/", "test:ts": "jest __tests__/TS/",
"test:stdlib": "jest __tests__/Stdlib_test.bs.js",
"test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*", "test:rescript": "jest --modulePathIgnorePatterns=__tests__/TS/*",
"test:watch": "jest --watchAll", "test:watch": "jest --watchAll",
"test:fnRegistry": "jest __tests__/SquiggleLibrary/SquiggleLibrary_FunctionRegistryLibrary_test.bs.js", "test:fnRegistry": "jest __tests__/SquiggleLibrary/SquiggleLibrary_FunctionRegistryLibrary_test.bs.js",

View File

@ -216,7 +216,7 @@ let rec run = (~env: env, functionCallInfo: functionCallInfo): outputType => {
| FromFloat(subFnName, x) => reCall(~functionCallInfo=FromFloat(subFnName, x), ()) | FromFloat(subFnName, x) => reCall(~functionCallInfo=FromFloat(subFnName, x), ())
| Mixture(dists) => | Mixture(dists) =>
dists dists
->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd) ->GenericDist.mixture(~scaleMultiplyFn=scaleMultiply, ~pointwiseAddFn=pointwiseAdd, ~env)
->E.R2.fmap(r => Dist(r)) ->E.R2.fmap(r => Dist(r))
->OutputLocal.fromResult ->OutputLocal.fromResult
| FromSamples(xs) => | FromSamples(xs) =>

View File

@ -499,15 +499,30 @@ let pointwiseCombinationFloat = (
m->E.R2.fmap(r => DistributionTypes.PointSet(r)) m->E.R2.fmap(r => DistributionTypes.PointSet(r))
} }
//Note: The result should always cumulatively sum to 1. This would be good to test. //TODO: The result should always cumulatively sum to 1. This would be good to test.
//Note: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this. //TODO: If the inputs are not normalized, this will return poor results. The weights probably refer to the post-normalized forms. It would be good to apply a catch to this.
let mixture = ( let mixture = (
values: array<(t, float)>, values: array<(t, float)>,
~scaleMultiplyFn: scaleMultiplyFn, ~scaleMultiplyFn: scaleMultiplyFn,
~pointwiseAddFn: pointwiseAddFn, ~pointwiseAddFn: pointwiseAddFn,
~env: env,
) => { ) => {
if E.A.length(values) == 0 { let allValuesAreSampleSet = v => E.A.all(((t, _)) => isSampleSetSet(t), v)
if E.A.isEmpty(values) {
Error(DistributionTypes.OtherError("Mixture error: mixture must have at least 1 element")) Error(DistributionTypes.OtherError("Mixture error: mixture must have at least 1 element"))
} else if allValuesAreSampleSet(values) {
let withSampleSetValues = values->E.A2.fmap(((value, weight)) =>
switch value {
| SampleSet(sampleSet) => Ok((sampleSet, weight))
| _ => Error("Unreachable")
}->E.R2.toExn("Mixture coding error: SampleSet expected. This should be inaccessible.")
)
let sampleSetMixture = SampleSetDist.mixture(withSampleSetValues, env.sampleCount)
switch sampleSetMixture {
| Ok(sampleSet) => Ok(DistributionTypes.SampleSet(sampleSet))
| Error(err) => Error(DistributionTypes.Error.sampleErrorToDistErr(err))
}
} else { } else {
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
let properlyWeightedValues = let properlyWeightedValues =

View File

@ -81,6 +81,7 @@ let mixture: (
array<(t, float)>, array<(t, float)>,
~scaleMultiplyFn: scaleMultiplyFn, ~scaleMultiplyFn: scaleMultiplyFn,
~pointwiseAddFn: pointwiseAddFn, ~pointwiseAddFn: pointwiseAddFn,
~env: env,
) => result<t, error> ) => result<t, error>
let isSymbolic: t => bool let isSymbolic: t => bool

View File

@ -224,3 +224,8 @@ module T = Dist({
XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares) XYShape.Analysis.getVarianceDangerously(t, mean, getMeanOfSquares)
} }
}) })
let sampleN = (t: t, n): array<float> => {
let normalized = t->T.normalize->getShape
Stdlib.Random.sample(normalized.xs, {probs: normalized.ys, size: n})
}

View File

@ -257,3 +257,7 @@ let toSparkline = (t: t, bucketCount): result<string, PointSetTypes.sparklineErr
->E.O2.fmap(Continuous.downsampleEquallyOverX(bucketCount)) ->E.O2.fmap(Continuous.downsampleEquallyOverX(bucketCount))
->E.O2.toResult(PointSetTypes.CannotSparklineDiscrete) ->E.O2.toResult(PointSetTypes.CannotSparklineDiscrete)
->E.R2.fmap(r => Continuous.getShape(r).ys->Sparklines.create()) ->E.R2.fmap(r => Continuous.getShape(r).ys->Sparklines.create())
let makeDiscrete = (d): t => Discrete(d)
let makeContinuous = (d): t => Continuous(d)
let makeMixed = (d): t => Mixed(d)

View File

@ -132,6 +132,25 @@ let stdev = t => T.get(t)->E.A.Floats.stdev
let variance = t => T.get(t)->E.A.Floats.variance let variance = t => T.get(t)->E.A.Floats.variance
let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f) let percentile = (t, f) => T.get(t)->E.A.Floats.percentile(f)
let mixture = (values: array<(t, float)>, intendedLength: int) => {
let totalWeight = values->E.A2.fmap(E.Tuple2.second)->E.A.Floats.sum
let discreteSamples =
values
->Belt.Array.mapWithIndex((i, (_, weight)) => (E.I.toFloat(i), weight /. totalWeight))
->XYShape.T.fromZippedArray
->Discrete.make
->Discrete.sampleN(intendedLength)
let dists = values->E.A2.fmap(E.Tuple2.first)->E.A2.fmap(T.get)
let samples =
discreteSamples
->Belt.Array.mapWithIndex((index, distIndexToChoose) => {
let chosenDist = E.A.get(dists, E.Float.toInt(distIndexToChoose))
chosenDist->E.O.bind(E.A.get(_, index))
})
->E.A.O.openIfAllSome
samples->E.O2.toExn("Mixture unreachable error")->T.make
}
let truncateLeft = (t, f) => T.get(t)->E.A2.filter(x => x >= f)->T.make let truncateLeft = (t, f) => T.get(t)->E.A2.filter(x => x >= f)->T.make
let truncateRight = (t, f) => T.get(t)->E.A2.filter(x => x <= f)->T.make let truncateRight = (t, f) => T.get(t)->E.A2.filter(x => x <= f)->T.make

View File

@ -220,6 +220,7 @@ module I = {
let increment = n => n + 1 let increment = n => n + 1
let decrement = n => n - 1 let decrement = n => n - 1
let toString = Js.Int.toString let toString = Js.Int.toString
let toFloat = Js.Int.toFloat
} }
exception Assertion(string) exception Assertion(string)

View File

@ -38,3 +38,12 @@ module Logistic = {
@module external variance: (float, float) => float = "@stdlib/stats/base/dists/logistic/variance" @module external variance: (float, float) => float = "@stdlib/stats/base/dists/logistic/variance"
let variance = variance let variance = variance
} }
module Random = {
type sampleArgs = {
probs: array<float>,
size: int,
}
@module external sample: (array<float>, sampleArgs) => array<float> = "@stdlib/random/sample"
let sample = sample
}

View File

@ -14,6 +14,7 @@ module.exports = {
}, },
resolve: { resolve: {
extensions: [".tsx", ".ts", ".js"], extensions: [".tsx", ".ts", ".js"],
fallback: { buffer: ["@stdlib/buffer"] },
}, },
output: { output: {
filename: "bundle.js", filename: "bundle.js",