Merge branch 'develop' into overhaul

This commit is contained in:
Vyacheslav Matyukhin 2022-09-24 20:29:00 +04:00
commit 2e9dabccd9
No known key found for this signature in database
GPG Key ID: 3D2A774C5489F96C
3 changed files with 81 additions and 61 deletions

View File

@ -1,5 +1,11 @@
import * as yup from "yup"; import * as yup from "yup";
import { SqDistribution, result, SqRecord } from "@quri/squiggle-lang"; import {
SqValue,
SqValueTag,
SqDistribution,
result,
SqRecord,
} from "@quri/squiggle-lang";
export type LabeledDistribution = { export type LabeledDistribution = {
name: string; name: string;
@ -21,48 +27,55 @@ function ok<a, b>(x: a): result<a, b> {
const schema = yup const schema = yup
.object() .object()
.strict()
.noUnknown() .noUnknown()
.strict()
.shape({ .shape({
distributions: yup.object().shape({ distributions: yup
tag: yup.mixed().oneOf(["array"]),
value: yup
.array() .array()
.required()
.of( .of(
yup.object().shape({ yup.object().required().shape({
tag: yup.mixed().oneOf(["record"]), name: yup.string().required(),
value: yup.object({ distribution: yup.mixed().required(),
name: yup.object().shape({
tag: yup.mixed().oneOf(["string"]),
value: yup.string().required(),
}),
// color: yup
// .object({
// tag: yup.mixed().oneOf(["string"]),
// value: yup.string().required(),
// })
// .default(undefined),
distribution: yup.object({
tag: yup.mixed().oneOf(["distribution"]),
value: yup.mixed(),
}),
}),
}) })
) ),
.required(),
}),
}); });
type JsonObject =
| string
| { [key: string]: JsonObject }
| JsonObject[]
| SqDistribution;
function toJson(val: SqValue): JsonObject {
if (val.tag === SqValueTag.String) {
return val.value;
} else if (val.tag === SqValueTag.Record) {
return toJsonRecord(val.value);
} else if (val.tag === SqValueTag.Array) {
return val.value.getValues().map(toJson);
} else if (val.tag === SqValueTag.Distribution) {
return val.value;
} else {
throw new Error("Could not parse object of type " + val.tag);
}
}
function toJsonRecord(val: SqRecord): JsonObject {
let recordObject: JsonObject = {};
val.entries().forEach(([key, value]) => (recordObject[key] = toJson(value)));
return recordObject;
}
export function parsePlot(record: SqRecord): result<Plot, string> { export function parsePlot(record: SqRecord): result<Plot, string> {
try { try {
const plotRecord = schema.validateSync(record); const plotRecord = schema.validateSync(toJsonRecord(record));
return ok({ if (plotRecord.distributions) {
distributions: plotRecord.distributions.value.map((x) => ({ return ok({ distributions: plotRecord.distributions.map((x) => x) });
name: x.value.name.value, } else {
// color: x.value.color?.value, // not supported yet // I have no idea why yup's typings thinks this is possible
distribution: x.value.distribution.value, return error("no distributions field. Should never get here");
})), }
});
} catch (e) { } catch (e) {
const message = e instanceof Error ? e.message : "Unknown error"; const message = e instanceof Error ? e.message : "Unknown error";
return error(message); return error(message);

View File

@ -25,6 +25,12 @@ describe("Continuous and discrete splits", () => {
([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], []), ([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], []),
) )
makeTest(
"more general test",
prepareInputs([10., 10., 11., 11., 11., 12., 13., 13., 13., 13., 13., 14.], 3),
([10., 10., 12., 14.], [(11., 3.), (13., 5.)]),
)
let makeDuplicatedArray = count => { let makeDuplicatedArray = count => {
let arr = Belt.Array.range(1, count) |> E.A.fmap(float_of_int) let arr = Belt.Array.range(1, count) |> E.A.fmap(float_of_int)
let sorted = arr |> Belt.SortArray.stableSortBy(_, compare) let sorted = arr |> Belt.SortArray.stableSortBy(_, compare)

View File

@ -309,9 +309,12 @@ module Floats = {
continuous samples and discrete samples. The discrete samples are stored in a mutable map. continuous samples and discrete samples. The discrete samples are stored in a mutable map.
Samples are thought to be discrete if they have at least `minDiscreteWight` duplicates. Samples are thought to be discrete if they have at least `minDiscreteWight` duplicates.
If the min discreet weight is 4, that would mean that at least four elements needed from a specific If the min discrete weight is 4, that would mean that at least four elements needed from a specific
value for that to be kept as discrete. This is important because in some cases, we can expect that value for that to be kept as discrete. This is important because in some cases, we can expect that
some common elements will be generated by regular operations. The final continous array will be sorted. some common elements will be generated by regular operations. The final continuous array will be sorted.
This function is performance-critical, don't change it significantly without benchmarking
SampleSet->PointSet conversion performance.
*/ */
let splitContinuousAndDiscreteForMinWeight = ( let splitContinuousAndDiscreteForMinWeight = (
sortedArray: array<float>, sortedArray: array<float>,
@ -320,34 +323,32 @@ module Floats = {
let continuous: array<float> = [] let continuous: array<float> = []
let discrete = FloatFloatMap.empty() let discrete = FloatFloatMap.empty()
let flush = (cnt: int, value: float): unit => { let addData = (count: int, value: float): unit => {
if cnt >= minDiscreteWeight { if count >= minDiscreteWeight {
FloatFloatMap.add(value, cnt->Belt.Int.toFloat, discrete) FloatFloatMap.add(value, count->Belt.Int.toFloat, discrete)
} else { } else {
for _ in 1 to cnt { for _ in 1 to count {
let _ = continuous->Js.Array2.push(value) continuous->Js.Array2.push(value)->ignore
} }
} }
} }
if sortedArray->Js.Array2.length != 0 { let (finalCount, finalValue) = sortedArray->Belt.Array.reduce(
let (finalCnt, finalValue) = sortedArray->Belt.Array.reduce(
// initial prev value doesn't matter; if it collides with the first element of the array, flush won't do anything // initial prev value doesn't matter; if it collides with the first element of the array, flush won't do anything
(0, 0.), (0, 0.),
((cnt, prev), element) => { ((count, prev), element) => {
if element == prev { if element == prev {
(cnt + 1, prev) (count + 1, prev)
} else { } else {
// new value, process previous ones // new value, process previous ones
flush(cnt, prev) addData(count, prev)
(1, element) (1, element)
} }
}, },
) )
// flush final values // flush final values
flush(finalCnt, finalValue) addData(finalCount, finalValue)
}
(continuous, discrete) (continuous, discrete)
} }