Merge branch 'develop' into overhaul
This commit is contained in:
commit
2e9dabccd9
|
@ -1,5 +1,11 @@
|
|||
import * as yup from "yup";
|
||||
import { SqDistribution, result, SqRecord } from "@quri/squiggle-lang";
|
||||
import {
|
||||
SqValue,
|
||||
SqValueTag,
|
||||
SqDistribution,
|
||||
result,
|
||||
SqRecord,
|
||||
} from "@quri/squiggle-lang";
|
||||
|
||||
export type LabeledDistribution = {
|
||||
name: string;
|
||||
|
@ -21,48 +27,55 @@ function ok<a, b>(x: a): result<a, b> {
|
|||
|
||||
const schema = yup
|
||||
.object()
|
||||
.strict()
|
||||
.noUnknown()
|
||||
.strict()
|
||||
.shape({
|
||||
distributions: yup.object().shape({
|
||||
tag: yup.mixed().oneOf(["array"]),
|
||||
value: yup
|
||||
distributions: yup
|
||||
.array()
|
||||
.required()
|
||||
.of(
|
||||
yup.object().shape({
|
||||
tag: yup.mixed().oneOf(["record"]),
|
||||
value: yup.object({
|
||||
name: yup.object().shape({
|
||||
tag: yup.mixed().oneOf(["string"]),
|
||||
value: yup.string().required(),
|
||||
}),
|
||||
// color: yup
|
||||
// .object({
|
||||
// tag: yup.mixed().oneOf(["string"]),
|
||||
// value: yup.string().required(),
|
||||
// })
|
||||
// .default(undefined),
|
||||
distribution: yup.object({
|
||||
tag: yup.mixed().oneOf(["distribution"]),
|
||||
value: yup.mixed(),
|
||||
}),
|
||||
}),
|
||||
yup.object().required().shape({
|
||||
name: yup.string().required(),
|
||||
distribution: yup.mixed().required(),
|
||||
})
|
||||
)
|
||||
.required(),
|
||||
}),
|
||||
),
|
||||
});
|
||||
|
||||
type JsonObject =
|
||||
| string
|
||||
| { [key: string]: JsonObject }
|
||||
| JsonObject[]
|
||||
| SqDistribution;
|
||||
|
||||
function toJson(val: SqValue): JsonObject {
|
||||
if (val.tag === SqValueTag.String) {
|
||||
return val.value;
|
||||
} else if (val.tag === SqValueTag.Record) {
|
||||
return toJsonRecord(val.value);
|
||||
} else if (val.tag === SqValueTag.Array) {
|
||||
return val.value.getValues().map(toJson);
|
||||
} else if (val.tag === SqValueTag.Distribution) {
|
||||
return val.value;
|
||||
} else {
|
||||
throw new Error("Could not parse object of type " + val.tag);
|
||||
}
|
||||
}
|
||||
|
||||
function toJsonRecord(val: SqRecord): JsonObject {
|
||||
let recordObject: JsonObject = {};
|
||||
val.entries().forEach(([key, value]) => (recordObject[key] = toJson(value)));
|
||||
return recordObject;
|
||||
}
|
||||
|
||||
export function parsePlot(record: SqRecord): result<Plot, string> {
|
||||
try {
|
||||
const plotRecord = schema.validateSync(record);
|
||||
return ok({
|
||||
distributions: plotRecord.distributions.value.map((x) => ({
|
||||
name: x.value.name.value,
|
||||
// color: x.value.color?.value, // not supported yet
|
||||
distribution: x.value.distribution.value,
|
||||
})),
|
||||
});
|
||||
const plotRecord = schema.validateSync(toJsonRecord(record));
|
||||
if (plotRecord.distributions) {
|
||||
return ok({ distributions: plotRecord.distributions.map((x) => x) });
|
||||
} else {
|
||||
// I have no idea why yup's typings thinks this is possible
|
||||
return error("no distributions field. Should never get here");
|
||||
}
|
||||
} catch (e) {
|
||||
const message = e instanceof Error ? e.message : "Unknown error";
|
||||
return error(message);
|
||||
|
|
|
@ -25,6 +25,12 @@ describe("Continuous and discrete splits", () => {
|
|||
([1.33455, 1.432, 2.0, 2.0, 3.5, 3.5, 3.5], []),
|
||||
)
|
||||
|
||||
makeTest(
|
||||
"more general test",
|
||||
prepareInputs([10., 10., 11., 11., 11., 12., 13., 13., 13., 13., 13., 14.], 3),
|
||||
([10., 10., 12., 14.], [(11., 3.), (13., 5.)]),
|
||||
)
|
||||
|
||||
let makeDuplicatedArray = count => {
|
||||
let arr = Belt.Array.range(1, count) |> E.A.fmap(float_of_int)
|
||||
let sorted = arr |> Belt.SortArray.stableSortBy(_, compare)
|
||||
|
|
|
@ -309,9 +309,12 @@ module Floats = {
|
|||
continuous samples and discrete samples. The discrete samples are stored in a mutable map.
|
||||
Samples are thought to be discrete if they have at least `minDiscreteWeight` duplicates.
|
||||
|
||||
If the min discreet weight is 4, that would mean that at least four elements needed from a specific
|
||||
If the min discrete weight is 4, that would mean that at least four elements are needed from a specific
|
||||
value for that to be kept as discrete. This is important because in some cases, we can expect that
|
||||
some common elements will be generated by regular operations. The final continous array will be sorted.
|
||||
some common elements will be generated by regular operations. The final continuous array will be sorted.
|
||||
|
||||
This function is performance-critical, don't change it significantly without benchmarking
|
||||
SampleSet->PointSet conversion performance.
|
||||
*/
|
||||
let splitContinuousAndDiscreteForMinWeight = (
|
||||
sortedArray: array<float>,
|
||||
|
@ -320,34 +323,32 @@ module Floats = {
|
|||
let continuous: array<float> = []
|
||||
let discrete = FloatFloatMap.empty()
|
||||
|
||||
let flush = (cnt: int, value: float): unit => {
|
||||
if cnt >= minDiscreteWeight {
|
||||
FloatFloatMap.add(value, cnt->Belt.Int.toFloat, discrete)
|
||||
let addData = (count: int, value: float): unit => {
|
||||
if count >= minDiscreteWeight {
|
||||
FloatFloatMap.add(value, count->Belt.Int.toFloat, discrete)
|
||||
} else {
|
||||
for _ in 1 to cnt {
|
||||
let _ = continuous->Js.Array2.push(value)
|
||||
for _ in 1 to count {
|
||||
continuous->Js.Array2.push(value)->ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sortedArray->Js.Array2.length != 0 {
|
||||
let (finalCnt, finalValue) = sortedArray->Belt.Array.reduce(
|
||||
let (finalCount, finalValue) = sortedArray->Belt.Array.reduce(
|
||||
// initial prev value doesn't matter; if it collides with the first element of the array, flush won't do anything
|
||||
(0, 0.),
|
||||
((cnt, prev), element) => {
|
||||
((count, prev), element) => {
|
||||
if element == prev {
|
||||
(cnt + 1, prev)
|
||||
(count + 1, prev)
|
||||
} else {
|
||||
// new value, process previous ones
|
||||
flush(cnt, prev)
|
||||
addData(count, prev)
|
||||
(1, element)
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
// flush final values
|
||||
flush(finalCnt, finalValue)
|
||||
}
|
||||
addData(finalCount, finalValue)
|
||||
|
||||
(continuous, discrete)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user