Merge branch 'develop' into overhaul

This commit is contained in:
Vyacheslav Matyukhin 2022-09-24 20:29:00 +04:00
commit 2e9dabccd9
No known key found for this signature in database
GPG Key ID: 3D2A774C5489F96C
3 changed files with 81 additions and 61 deletions

View File

@ -1,5 +1,11 @@
import * as yup from "yup";
import { SqDistribution, result, SqRecord } from "@quri/squiggle-lang";
import {
SqValue,
SqValueTag,
SqDistribution,
result,
SqRecord,
} from "@quri/squiggle-lang";
export type LabeledDistribution = {
name: string;
@ -21,48 +27,55 @@ function ok<a, b>(x: a): result<a, b> {
// yup schema that parsePlot runs (via validateSync) over the plot record before
// building its result.
// NOTE(review): this span appears to show both sides of a merge hunk at once, which
// is why `.strict()` and the `distributions` key each occur twice. Read as a single
// object literal, the later `distributions` entry is the one that takes effect —
// confirm which side is intended before treating this text as compilable source.
const schema = yup
.object()
.strict()
.noUnknown()
.strict()
.shape({
// First `distributions` entry: validates tagged wrappers of the form { tag, value },
// i.e. an "array" of "record"s whose fields are themselves { tag, value } objects.
distributions: yup.object().shape({
tag: yup.mixed().oneOf(["array"]),
value: yup
.array()
.of(
yup.object().shape({
tag: yup.mixed().oneOf(["record"]),
value: yup.object({
name: yup.object().shape({
tag: yup.mixed().oneOf(["string"]),
value: yup.string().required(),
}),
// color: yup
// .object({
// tag: yup.mixed().oneOf(["string"]),
// value: yup.string().required(),
// })
// .default(undefined),
distribution: yup.object({
tag: yup.mixed().oneOf(["distribution"]),
value: yup.mixed(),
}),
}),
})
)
.required(),
}),
// Second `distributions` entry: validates a plain, required array of
// { name: string, distribution: <anything, but required> } objects.
distributions: yup
.array()
.required()
.of(
yup.object().required().shape({
name: yup.string().required(),
distribution: yup.mixed().required(),
})
),
});
/**
 * Plain-data form of a Squiggle value as produced by `toJson`:
 * a string, a string-keyed object of nested values, an array of nested values,
 * or an `SqDistribution`, which is passed through as-is rather than converted.
 */
type JsonObject =
| string
| { [key: string]: JsonObject }
| JsonObject[]
| SqDistribution;
/**
 * Unwraps a Squiggle value into its `JsonObject` counterpart.
 *
 * Strings and distributions are returned as their underlying `value`; records are
 * converted with `toJsonRecord`; arrays are converted element by element.
 *
 * @param val - The Squiggle value to convert.
 * @returns The plain-data representation of `val`.
 * @throws Error if `val.tag` is not String, Record, Array or Distribution.
 */
function toJson(val: SqValue): JsonObject {
  switch (val.tag) {
    case SqValueTag.String:
      return val.value;
    case SqValueTag.Record:
      return toJsonRecord(val.value);
    case SqValueTag.Array:
      // Recurse into every element of the Squiggle array.
      return val.value.getValues().map(toJson);
    case SqValueTag.Distribution:
      // Distributions are kept as SqDistribution objects, not serialized.
      return val.value;
    default:
      throw new Error("Could not parse object of type " + val.tag);
  }
}
/**
 * Converts a Squiggle record into a plain object with one property per record entry.
 *
 * Every entry value is converted with `toJson`; an error thrown by `toJson` for an
 * unsupported value is not caught here and propagates to the caller.
 *
 * @param val - The Squiggle record to convert.
 * @returns A string-keyed object holding the converted entry values.
 */
function toJsonRecord(val: SqRecord): JsonObject {
  // Typed as the record variant of JsonObject rather than the whole union, so the
  // string-keyed write inside the callback is checked against an index signature.
  // `const` because the binding itself is never reassigned.
  const recordObject: { [key: string]: JsonObject } = {};
  val.entries().forEach(([key, value]) => {
    recordObject[key] = toJson(value);
  });
  return recordObject;
}
export function parsePlot(record: SqRecord): result<Plot, string> {
try {
const plotRecord = schema.validateSync(record);
return ok({
distributions: plotRecord.distributions.value.map((x) => ({
name: x.value.name.value,
// color: x.value.color?.value, // not supported yet
distribution: x.value.distribution.value,
})),
});
const plotRecord = schema.validateSync(toJsonRecord(record));
if (plotRecord.distributions) {
return ok({ distributions: plotRecord.distributions.map((x) => x) });
} else {
// I have no idea why yup's typings thinks this is possible
return error("no distributions field. Should never get here");
}
} catch (e) {
const message = e instanceof Error ? e.message : "Unknown error";
return error(message);

View File

@ -25,6 +25,12 @@ describe("Continuous and discrete splits", () => {
([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], []),
)
// Mixed input: 11. occurs 3 times and 13. occurs 5 times, and the expected result
// lists them as (value, count) pairs; 10. (twice), 12. and 14. stay in the first array.
// NOTE(review): presumably the trailing 3 passed to prepareInputs is the minimum
// duplicate count for a value to be treated as discrete — verify against prepareInputs.
makeTest(
"more general test",
prepareInputs([10., 10., 11., 11., 11., 12., 13., 13., 13., 13., 13., 14.], 3),
([10., 10., 12., 14.], [(11., 3.), (13., 5.)]),
)
let makeDuplicatedArray = count => {
let arr = Belt.Array.range(1, count) |> E.A.fmap(float_of_int)
let sorted = arr |> Belt.SortArray.stableSortBy(_, compare)

View File

@ -309,9 +309,12 @@ module Floats = {
continuous samples and discrete samples. The discrete samples are stored in a mutable map.
Samples are thought to be discrete if they have at least `minDiscreteWeight` duplicates.
If the min discreet weight is 4, that would mean that at least four elements needed from a specific
If the min discrete weight is 4, that would mean that at least four elements needed from a specific
value for that to be kept as discrete. This is important because in some cases, we can expect that
some common elements will be generated by regular operations. The final continous array will be sorted.
some common elements will be generated by regular operations. The final continuous array will be sorted.
This function is performance-critical, don't change it significantly without benchmarking
SampleSet->PointSet conversion performance.
*/
let splitContinuousAndDiscreteForMinWeight = (
sortedArray: array<float>,
@ -320,34 +323,32 @@ module Floats = {
let continuous: array<float> = []
let discrete = FloatFloatMap.empty()
let flush = (cnt: int, value: float): unit => {
if cnt >= minDiscreteWeight {
FloatFloatMap.add(value, cnt->Belt.Int.toFloat, discrete)
let addData = (count: int, value: float): unit => {
if count >= minDiscreteWeight {
FloatFloatMap.add(value, count->Belt.Int.toFloat, discrete)
} else {
for _ in 1 to cnt {
let _ = continuous->Js.Array2.push(value)
for _ in 1 to count {
continuous->Js.Array2.push(value)->ignore
}
}
}
if sortedArray->Js.Array2.length != 0 {
let (finalCnt, finalValue) = sortedArray->Belt.Array.reduce(
// initial prev value doesn't matter; if it collides with the first element of the array, flush won't do anything
(0, 0.),
((cnt, prev), element) => {
if element == prev {
(cnt + 1, prev)
} else {
// new value, process previous ones
flush(cnt, prev)
(1, element)
}
},
)
let (finalCount, finalValue) = sortedArray->Belt.Array.reduce(
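// The initial (0, 0.) accumulator is harmless: if the first element equals 0. the count simply becomes 1;
// otherwise addData is called with a count of 0, which pushes nothing (assuming minDiscreteWeight is positive).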
(0, 0.),
((count, prev), element) => {
if element == prev {
(count + 1, prev)
} else {
// new value, process previous ones
addData(count, prev)
(1, element)
}
},
)
// flush final values
flush(finalCnt, finalValue)
}
// flush final values
addData(finalCount, finalValue)
(continuous, discrete)
}