Add plot function

This commit is contained in:
Sam Nolan 2022-10-10 17:37:21 +11:00
parent eaa7d38428
commit 62f735efcb
15 changed files with 261 additions and 174 deletions

View File

@ -3,8 +3,9 @@ import {
SqDistribution,
result,
SqDistributionError,
LabeledDistribution,
resultMap,
SqRecord,
SqPlot,
environment,
SqDistributionTag,
} from "@quri/squiggle-lang";
@ -17,7 +18,6 @@ import {
DistributionChartSpecOptions,
} from "../lib/distributionSpecBuilder";
import { NumberShower } from "./NumberShower";
import { Plot, parsePlot } from "../lib/plotParser";
import { flattenResult } from "../lib/utility";
import { hasMassBelowZero } from "../lib/distributionUtils";
@ -28,27 +28,15 @@ export type DistributionPlottingSettings = {
} & DistributionChartSpecOptions;
export type DistributionChartProps = {
plot: Plot;
environment: environment;
width?: number;
height: number;
xAxisType?: "number" | "dateTime";
} & DistributionPlottingSettings;
export function defaultPlot(distribution: SqDistribution): Plot {
return { distributions: [{ name: "default", distribution }] };
}
export function makePlot(record: SqRecord): Plot | void {
const plotResult = parsePlot(record);
if (plotResult.tag === "Ok") {
return plotResult.value;
}
}
} & DistributionPlottingSettings &
({ plot: SqPlot } | { distribution: SqDistribution });
export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
const {
plot,
environment,
height,
showSummary,
@ -57,8 +45,14 @@ export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
actions = false,
} = props;
const [sized] = useSize((size) => {
const shapes = flattenResult(
plot.distributions.map((x) =>
let distributions: LabeledDistribution[];
if ("plot" in props) {
distributions = props.plot.getDistributions();
} else {
distributions = [{ name: "default", distribution: props.distribution }];
}
let shapes = flattenResult(
distributions.map((x) =>
resultMap(x.distribution.pointSet(environment), (pointSet) => ({
name: x.name,
// color: x.color, // not supported yet
@ -77,7 +71,7 @@ export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
// if this is a sample set, include the samples
const samples: number[] = [];
for (const { distribution } of plot?.distributions) {
for (const { distribution } of distributions) {
if (distribution.tag === SqDistributionTag.SampleSet) {
samples.push(...distribution.value());
}
@ -126,9 +120,9 @@ export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
/>
)}
<div className="flex justify-center">
{showSummary && plot.distributions.length === 1 && (
{showSummary && distributions.length === 1 && (
<SummaryTable
distribution={plot.distributions[0].distribution}
distribution={distributions[0].distribution}
environment={environment}
/>
)}

View File

@ -15,7 +15,6 @@ import * as percentilesSpec from "../vega-specs/spec-percentiles.json";
import {
DistributionChart,
DistributionPlottingSettings,
defaultPlot,
} from "./DistributionChart";
import { NumberShower } from "./NumberShower";
import { ErrorAlert } from "./Alert";
@ -184,7 +183,7 @@ export const FunctionChart1Dist: React.FC<FunctionChart1DistProps> = ({
mouseItem.tag === "Ok" &&
mouseItem.value.tag === SqValueTag.Distribution ? (
<DistributionChart
plot={defaultPlot(mouseItem.value.value)}
distribution={mouseItem.value.value}
environment={environment}
width={400}
height={50}
@ -194,7 +193,7 @@ export const FunctionChart1Dist: React.FC<FunctionChart1DistProps> = ({
let getPercentilesMemoized = React.useMemo(
() => getPercentiles({ chartSettings, fn, environment }),
[environment, fn]
[chartSettings, environment, fn]
);
return (

View File

@ -1,7 +1,7 @@
import React, { useContext } from "react";
import { SqDistributionTag, SqValue, SqValueTag } from "@quri/squiggle-lang";
import { NumberShower } from "../NumberShower";
import { DistributionChart, defaultPlot, makePlot } from "../DistributionChart";
import { DistributionChart } from "../DistributionChart";
import { FunctionChart } from "../FunctionChart";
import clsx from "clsx";
import { VariableBox } from "./VariableBox";
@ -104,7 +104,7 @@ export const ExpressionViewer: React.FC<Props> = ({ value, width }) => {
{(settings) => {
return (
<DistributionChart
plot={defaultPlot(value.value)}
distribution={value.value}
environment={settings.environment}
{...settings.distributionPlotSettings}
height={settings.height}
@ -219,63 +219,61 @@ export const ExpressionViewer: React.FC<Props> = ({ value, width }) => {
</VariableBox>
);
}
case SqValueTag.Plot:
const plot = value.value;
return (
<VariableBox
value={value}
heading="Plot"
renderSettingsMenu={({ onChange }) => {
let disableLogX = plot.getDistributions().some((x) => {
let pointSet = x.distribution.pointSet(
getMergedSettings(value.location).environment
);
return (
pointSet.tag === "Ok" &&
hasMassBelowZero(pointSet.value.asShape())
);
});
return (
<ItemSettingsMenu
value={value}
onChange={onChange}
disableLogX={disableLogX}
withFunctionSettings={false}
/>
);
}}
>
{(settings) => {
return (
<DistributionChart
plot={plot}
environment={settings.environment}
{...settings.distributionPlotSettings}
height={settings.height}
width={width}
/>
);
}}
</VariableBox>
);
case SqValueTag.Record:
const plot = makePlot(value.value);
if (plot) {
return (
<VariableBox
value={value}
heading="Plot"
renderSettingsMenu={({ onChange }) => {
let disableLogX = plot.distributions.some((x) => {
let pointSet = x.distribution.pointSet(
getMergedSettings(value.location).environment
);
return (
pointSet.tag === "Ok" &&
hasMassBelowZero(pointSet.value.asShape())
);
});
return (
<ItemSettingsMenu
value={value}
onChange={onChange}
disableLogX={disableLogX}
withFunctionSettings={false}
return (
<VariableList value={value} heading="Record">
{(_) =>
value.value
.entries()
.map(([key, r]) => (
<ExpressionViewer
key={key}
value={r}
width={width !== undefined ? width - 20 : width}
/>
);
}}
>
{(settings) => {
return (
<DistributionChart
plot={plot}
environment={settings.environment}
{...settings.distributionPlotSettings}
height={settings.height}
width={width}
/>
);
}}
</VariableBox>
);
} else {
return (
<VariableList value={value} heading="Record">
{(_) =>
value.value
.entries()
.map(([key, r]) => (
<ExpressionViewer
key={key}
value={r}
width={width !== undefined ? width - 20 : width}
/>
))
}
</VariableList>
);
}
))
}
</VariableList>
);
case SqValueTag.Array:
return (
<VariableList value={value} heading="Array">

View File

@ -1,83 +0,0 @@
import * as yup from "yup";
import {
SqValue,
SqValueTag,
SqDistribution,
result,
SqRecord,
} from "@quri/squiggle-lang";
export type LabeledDistribution = {
name: string;
distribution: SqDistribution;
color?: string;
};
export type Plot = {
distributions: LabeledDistribution[];
};
function error<a, b>(err: b): result<a, b> {
return { tag: "Error", value: err };
}
function ok<a, b>(x: a): result<a, b> {
return { tag: "Ok", value: x };
}
const schema = yup
.object()
.noUnknown()
.strict()
.shape({
distributions: yup
.array()
.required()
.of(
yup.object().required().shape({
name: yup.string().required(),
distribution: yup.mixed().required(),
})
),
});
type JsonObject =
| string
| { [key: string]: JsonObject }
| JsonObject[]
| SqDistribution;
function toJson(val: SqValue): JsonObject {
if (val.tag === SqValueTag.String) {
return val.value;
} else if (val.tag === SqValueTag.Record) {
return toJsonRecord(val.value);
} else if (val.tag === SqValueTag.Array) {
return val.value.getValues().map(toJson);
} else if (val.tag === SqValueTag.Distribution) {
return val.value;
} else {
throw new Error("Could not parse object of type " + val.tag);
}
}
function toJsonRecord(val: SqRecord): JsonObject {
let recordObject: JsonObject = {};
val.entries().forEach(([key, value]) => (recordObject[key] = toJson(value)));
return recordObject;
}
export function parsePlot(record: SqRecord): result<Plot, string> {
try {
const plotRecord = schema.validateSync(toJsonRecord(record));
if (plotRecord.distributions) {
return ok({ distributions: plotRecord.distributions.map((x) => x) });
} else {
// I have no idea why yup's typings thinks this is possible
return error("no distributions field. Should never get here");
}
} catch (e) {
const message = e instanceof Error ? e.message : "Unknown error";
return error(message);
}
}

View File

@ -0,0 +1,25 @@
import * as RSPlot from "../rescript/ForTS/ForTS_SquiggleValue/ForTS_SquiggleValue_Plot.gen";
import { SqDistribution, wrapDistribution } from "./SqDistribution";
import { SqValueLocation } from "./SqValueLocation";
type T = RSPlot.squiggleValue_Plot;
export type LabeledDistribution = {
name: string;
distribution: SqDistribution;
};
export class SqPlot {
constructor(private _value: T, public location: SqValueLocation) {}
getDistributions(): LabeledDistribution[] {
return this._value.distributions.map((v: RSPlot.labeledDistribution) => ({
...v,
distribution: wrapDistribution(v.distribution),
}));
}
toString() {
return RSPlot.toString(this._value);
}
}

View File

@ -4,6 +4,7 @@ import { wrapDistribution } from "./SqDistribution";
import { SqLambda } from "./SqLambda";
import { SqLambdaDeclaration } from "./SqLambdaDeclaration";
import { SqRecord } from "./SqRecord";
import { SqPlot } from "./SqPlot";
import { SqArray } from "./SqArray";
import { SqValueLocation } from "./SqValueLocation";
@ -91,6 +92,14 @@ export class SqNumberValue extends SqAbstractValue {
}
}
export class SqPlotValue extends SqAbstractValue {
tag = Tag.Plot as const;
get value() {
return new SqPlot(this.valueMethod(RSValue.getPlot), this.location);
}
}
export class SqRecordValue extends SqAbstractValue {
tag = Tag.Record as const;
@ -131,6 +140,7 @@ const tagToClass = {
[Tag.Distribution]: SqDistributionValue,
[Tag.Lambda]: SqLambdaValue,
[Tag.Number]: SqNumberValue,
[Tag.Plot]: SqPlotValue,
[Tag.Record]: SqRecordValue,
[Tag.String]: SqStringValue,
[Tag.TimeDuration]: SqTimeDurationValue,
@ -148,6 +158,7 @@ export type SqValue =
| SqLambdaValue
| SqNumberValue
| SqRecordValue
| SqPlotValue
| SqStringValue
| SqTimeDurationValue
| SqVoidValue;

View File

@ -6,6 +6,7 @@ export { result } from "../rescript/ForTS/ForTS_Result_tag";
export { SqDistribution, SqDistributionTag } from "./SqDistribution";
export { SqDistributionError } from "./SqDistributionError";
export { SqRecord } from "./SqRecord";
export { SqPlot, LabeledDistribution } from "./SqPlot";
export { SqLambda } from "./SqLambda";
export { SqProject };
export { SqValue, SqValueTag };
@ -14,7 +15,7 @@ export {
defaultEnvironment,
} from "../rescript/ForTS/ForTS_Distribution/ForTS_Distribution.gen";
export { SqError, SqFrame, SqLocation } from "./SqError";
export { SqShape } from "./SqPointSetDist";
export { SqShape, SqPoint } from "./SqPointSetDist";
export { resultMap } from "./types";

View File

@ -0,0 +1,103 @@
open FunctionRegistry_Core
open FunctionRegistry_Helpers
let nameSpace = "Plot"
module Internals = {
let parseString = (a: Reducer_T.value): result<string, SqError.Message.t> => {
switch a {
| IEvString(s) => Ok(s)
| _ => Error(SqError.Message.REOther("Expected to be a string"))
}
}
let parseDistributionOrNumber = (a: Reducer_T.value): result<
GenericDist.t,
SqError.Message.t,
> => {
switch a {
| IEvDistribution(s) => Ok(s)
| IEvNumber(s) => Ok(GenericDist.fromFloat(s))
| _ => Error(SqError.Message.REOther("Expected to be a distribution"))
}
}
let parseArray = (
parser: Reducer_T.value => result<'a, SqError.Message.t>,
a: Reducer_T.value,
): result<array<'a>, SqError.Message.t> => {
switch a {
| IEvArray(x) => x->E.A2.fmap(parser)->E.A.R.firstErrorOrOpen
| _ => Error(SqError.Message.REOther("Expected to be an array"))
}
}
let parseRecord = (
parser: Reducer_T.map => result<'b, SqError.Message.t>,
a: Reducer_T.value,
): result<'b, SqError.Message.t> => {
switch a {
| IEvRecord(x) => parser(x)
| _ => Error(SqError.Message.REOther("Expected to be an array"))
}
}
let parseField = (
a: Reducer_T.map,
key: string,
parser: Reducer_T.value => result<'a, SqError.Message.t>,
): result<'a, SqError.Message.t> => {
switch Belt.Map.String.get(a, key) {
| Some(x) => parser(x)
| None => Error(SqError.Message.REOther("expected field " ++ key ++ " in plot dictionary."))
}
}
let parseLabeledDistribution = (a: Reducer_T.map): result<
Reducer_T.labeledDistribution,
SqError.Message.t,
> => {
let name = parseField(a, "name", parseString)
let distribution = parseField(a, "value", parseDistributionOrNumber)
switch E.R.merge(name, distribution) {
| Ok(name, distribution) => Ok({name: name, distribution: distribution})
| Error(err) => Error(err)
}
}
let parsePlotValue = (a: Reducer_T.map): result<Reducer_T.plotValue, SqError.Message.t> => {
parseField(a, "show", parseArray(parseRecord(parseLabeledDistribution)))->E.R2.fmap(dists => {
let plot: Reducer_T.plotValue = {distributions: dists}
plot
})
}
let dist = (a: Reducer_T.map): result<Reducer_T.value, SqError.Message.t> =>
E.R2.fmap(parsePlotValue(a), x => Reducer_T.IEvPlot(x))
}
let library = [
Function.make(
~name="dist",
~nameSpace,
~requiresNamespace=true,
~output=EvtPlot,
~examples=[
`Plot.dist({show: [{name: "Control", value: 1 to 2}, {name: "Treatment", value: 1.5 to 2.5}]}) `,
],
~definitions=[
FnDefinition.make(
~name="dist",
~inputs=[FRTypeDict(FRTypeAny)],
~run=(inputs, _, _) => {
switch inputs {
| [IEvRecord(plot)] => Internals.dist(plot)
| _ => impossibleError->Error
}
},
(),
),
],
(),
),
]

View File

@ -6,6 +6,7 @@ type error = SqError.t //use
type squiggleValue_Declaration = ForTS_SquiggleValue_Declaration.squiggleValue_Declaration //use
type squiggleValue_Distribution = ForTS_SquiggleValue_Distribution.squiggleValue_Distribution //use
type squiggleValue_Lambda = ForTS_SquiggleValue_Lambda.squiggleValue_Lambda //use
@genType type squiggleValue_Plot = Reducer_T.plotValue //use
// Return values are kept as they are if they are JavaScript types.
@ -30,6 +31,9 @@ external svtLambda_: string = "Lambda"
@module("./ForTS_SquiggleValue_tag") @scope("squiggleValueTag")
external svtNumber_: string = "Number"
@module("./ForTS_SquiggleValue_tag") @scope("squiggleValueTag")
external svtPlot_: string = "Plot"
@module("./ForTS_SquiggleValue_tag") @scope("squiggleValueTag")
external svtRecord_: string = "Record"
@ -57,6 +61,7 @@ let getTag = (variant: squiggleValue): squiggleValueTag =>
| IEvDistribution(_) => svtDistribution_->castEnum
| IEvLambda(_) => svtLambda_->castEnum
| IEvNumber(_) => svtNumber_->castEnum
| IEvPlot(_) => svtPlot_->castEnum
| IEvRecord(_) => svtRecord_->castEnum
| IEvString(_) => svtString_->castEnum
| IEvTimeDuration(_) => svtTimeDuration_->castEnum
@ -122,6 +127,13 @@ let getNumber = (variant: squiggleValue): option<float> =>
| _ => None
}
@genType
let getPlot = (variant: squiggleValue): option<squiggleValue_Plot> =>
switch variant {
| IEvPlot(value) => value->Some
| _ => None
}
@genType
let getRecord = (variant: squiggleValue): option<squiggleValue_Record> =>
switch variant {

View File

@ -0,0 +1,6 @@
type squiggleValue = ForTS_SquiggleValue.squiggleValue //use
@genType type squiggleValue_Plot = ForTS_SquiggleValue.squiggleValue_Plot //re-export recursive type
@genType type labeledDistribution = Reducer_T.labeledDistribution // use
@genType
let toString = (v: squiggleValue_Plot) => Reducer_Value.toStringPlot(v)

View File

@ -6,6 +6,7 @@ export enum squiggleValueTag {
Distribution = "Distribution",
Lambda = "Lambda",
Number = "Number",
Plot = "Plot",
Record = "Record",
String = "String",
TimeDuration = "TimeDuration",

View File

@ -1,18 +1,19 @@
let fnList = Belt.Array.concatMany([
FR_Builtin.library,
FR_Danger.library,
FR_Date.library,
FR_Dict.library,
FR_Dist.library,
FR_Danger.library,
FR_Fn.library,
FR_Sampleset.library,
FR_List.library,
FR_Number.library,
FR_Pointset.library,
FR_Scoring.library,
FR_GenericDist.library,
FR_Units.library,
FR_Date.library,
FR_List.library,
FR_Math.library,
FR_Number.library,
FR_Plot.library,
FR_Pointset.library,
FR_Sampleset.library,
FR_Scoring.library,
FR_Units.library,
])
let registry = FunctionRegistry_Core.Registry.make(fnList)

View File

@ -9,10 +9,12 @@ type rec value =
| IEvDistribution(DistributionTypes.genericDist)
| IEvLambda(lambdaValue)
| IEvNumber(float)
| IEvPlot(plotValue)
| IEvRecord(map)
| IEvString(string)
| IEvTimeDuration(float)
| IEvVoid
@genType.opaque and arrayValue = array<value>
@genType.opaque and map = Belt.Map.String.t<value>
and lambdaBody = (array<value>, context, reducerFn) => value
@ -66,4 +68,12 @@ and context = {
and reducerFn = (expression, context) => (value, context)
@genType and plotValue = {distributions: array<labeledDistribution>}
@genType
and labeledDistribution = {
name: string,
distribution: DistributionTypes.genericDist,
}
let topFrameName = "<top>"

View File

@ -14,6 +14,7 @@ let rec toString = (aValue: T.value) =>
| IEvDistribution(dist) => toStringDistribution(dist)
| IEvLambda(lambdaValue) => toStringLambda(lambdaValue)
| IEvNumber(aNumber) => toStringNumber(aNumber)
| IEvPlot(aPlot) => toStringPlot(aPlot)
| IEvRecord(aMap) => aMap->toStringRecord
| IEvString(aString) => toStringString(aString)
| IEvTimeDuration(t) => toStringTimeDuration(t)
@ -35,6 +36,10 @@ and toStringLambda = (lambdaValue: T.lambdaValue) => {
}
}
and toStringNumber = aNumber => Js.String.make(aNumber)
and toStringPlot = aPlot => {
let chartNames = E.A.fmap((x: Reducer_T.labeledDistribution) => x.name, aPlot.distributions)
`Plot showing ${Js.Array2.toString(chartNames)}`
}
and toStringRecord = aMap => aMap->toStringMap
and toStringString = aString => `'${aString}'`
and toStringSymbol = aString => `:${aString}`
@ -59,6 +64,7 @@ let toStringWithType = (aValue: T.value) =>
| IEvDistribution(_) => `Distribution::${toString(aValue)}`
| IEvLambda(_) => `Lambda::${toString(aValue)}`
| IEvNumber(_) => `Number::${toString(aValue)}`
| IEvPlot(_) => `Plot::${toString(aValue)}`
| IEvRecord(_) => `Record::${toString(aValue)}`
| IEvString(_) => `String::${toString(aValue)}`
| IEvTimeDuration(_) => `Date::${toString(aValue)}`
@ -91,6 +97,7 @@ type internalExpressionValueType =
| EvtDistribution
| EvtLambda
| EvtNumber
| EvtPlot
| EvtRecord
| EvtString
| EvtTimeDuration
@ -109,6 +116,7 @@ let valueToValueType = (value: T.value) =>
| IEvDistribution(_) => EvtDistribution
| IEvLambda(_) => EvtLambda
| IEvNumber(_) => EvtNumber
| IEvPlot(_) => EvtPlot
| IEvRecord(_) => EvtRecord
| IEvString(_) => EvtString
| IEvTimeDuration(_) => EvtTimeDuration
@ -129,6 +137,7 @@ let valueTypeToString = (valueType: internalExpressionValueType): string =>
| EvtDistribution => `Distribution`
| EvtLambda => `Lambda`
| EvtNumber => `Number`
| EvtPlot => `Plot`
| EvtRecord => `Record`
| EvtString => `String`
| EvtTimeDuration => `Duration`