Refactor specification to include discrete

This commit is contained in:
Sam Nolan 2022-07-13 14:15:07 +10:00
parent 9cbeee0451
commit 98ae0459c9
3 changed files with 157 additions and 161 deletions

View File

@ -5,6 +5,7 @@ import {
distributionError, distributionError,
distributionErrorToString, distributionErrorToString,
squiggleExpression, squiggleExpression,
resultMap,
} from "@quri/squiggle-lang"; } from "@quri/squiggle-lang";
import { Vega } from "react-vega"; import { Vega } from "react-vega";
import { ErrorAlert } from "./Alert"; import { ErrorAlert } from "./Alert";
@ -24,8 +25,10 @@ export type DistributionPlottingSettings = {
showControls: boolean; showControls: boolean;
} & DistributionChartSpecOptions; } & DistributionChartSpecOptions;
export type LabeledDistribution = { name: string; distribution: Distribution };
export type Plot = { export type Plot = {
distributions: Distribution[]; distributions: LabeledDistribution[];
}; };
export type DistributionChartProps = { export type DistributionChartProps = {
@ -36,19 +39,29 @@ export type DistributionChartProps = {
} & DistributionPlottingSettings; } & DistributionPlottingSettings;
export function defaultPlot(distribution: Distribution): Plot { export function defaultPlot(distribution: Distribution): Plot {
return { distributions: [distribution] }; return { distributions: [{ name: "default", distribution }] };
} }
export function makePlot(expression: { export function makePlot(expression: {
[key: string]: squiggleExpression; [key: string]: squiggleExpression;
}): Plot | void { }): Plot | void {
if (expression["distributions"].tag === "array") { if (expression["distributions"].tag === "array") {
let distributions: Distribution[] = expression["distributions"].value let distributions: LabeledDistribution[] = expression["distributions"].value
.map((x) => { .map((x) => {
if (x.tag === "distribution") { if (
return x.value; x.tag === "record" &&
x.value["name"] &&
x.value["name"].tag === "string" &&
x.value["distribution"] &&
x.value["distribution"].tag === "distribution"
) {
return {
name: x.value["name"].value,
distribution: x.value["distribution"].value,
};
} }
}) })
.filter((x): x is Distribution => x !== undefined); .filter((x): x is LabeledDistribution => x !== undefined);
return { distributions }; return { distributions };
} }
} }
@ -91,7 +104,15 @@ export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
React.useEffect(() => setExpY(expY), [expY]); React.useEffect(() => setExpY(expY), [expY]);
const [sized] = useSize((size) => { const [sized] = useSize((size) => {
let shapes = flattenResult(plot.distributions.map((x) => x.pointSet())); let shapes = flattenResult(
plot.distributions.map((x) =>
resultMap(x.distribution.pointSet(), (shape) => ({
name: x.name,
continuous: shape.continuous,
discrete: shape.discrete,
}))
)
);
if (shapes.tag === "Error") { if (shapes.tag === "Error") {
return ( return (
<ErrorAlert heading="Distribution Error"> <ErrorAlert heading="Distribution Error">
@ -116,20 +137,17 @@ export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
); );
widthProp = 20; widthProp = 20;
} }
let continuousPoints = shapes.value.flatMap((shape, i) => const domain = shapes.value.flatMap((shape) =>
shape.continuous.map((point) => ({ ...point, name: i + 1 })) shape.discrete.concat(shape.continuous)
);
let discretePoints = shapes.value.flatMap((shape, i) =>
shape.discrete.map((point) => ({ ...point, name: i + 1 }))
); );
console.log(shapes.value);
console.log(continuousPoints);
return ( return (
<div style={{ width: widthProp }}> <div style={{ width: widthProp }}>
{!(isLogX && massBelow0) ? ( {!(isLogX && massBelow0) ? (
<Vega <Vega
spec={spec} spec={spec}
data={{ con: continuousPoints, dis: discretePoints }} data={{ data: shapes.value, domain }}
width={widthProp - 10} width={widthProp - 10}
height={height} height={height}
actions={actions} actions={actions}
@ -141,7 +159,7 @@ export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
)} )}
<div className="flex justify-center"> <div className="flex justify-center">
{showSummary && plot.distributions.length == 1 && ( {showSummary && plot.distributions.length == 1 && (
<SummaryTable distribution={plot.distributions[0]} /> <SummaryTable distribution={plot.distributions[0].distribution} />
)} )}
</div> </div>
{showControls && ( {showControls && (

View File

@ -16,6 +16,7 @@ import * as percentilesSpec from "../vega-specs/spec-percentiles.json";
import { import {
DistributionChart, DistributionChart,
DistributionPlottingSettings, DistributionPlottingSettings,
defaultPlot,
} from "./DistributionChart"; } from "./DistributionChart";
import { NumberShower } from "./NumberShower"; import { NumberShower } from "./NumberShower";
import { ErrorAlert } from "./Alert"; import { ErrorAlert } from "./Alert";
@ -177,7 +178,7 @@ export const FunctionChart1Dist: React.FC<FunctionChart1DistProps> = ({
let showChart = let showChart =
mouseItem.tag === "Ok" && mouseItem.value.tag === "distribution" ? ( mouseItem.tag === "Ok" && mouseItem.value.tag === "distribution" ? (
<DistributionChart <DistributionChart
distribution={mouseItem.value.value} plot={defaultPlot(mouseItem.value.value)}
width={400} width={400}
height={50} height={50}
{...distributionPlotSettings} {...distributionPlotSettings}

View File

@ -25,36 +25,14 @@ export let linearXScale: LinearScale = {
range: "width", range: "width",
zero: false, zero: false,
nice: false, nice: false,
domain: { domain: { data: "domain", field: "x" },
fields: [
{
data: "con",
field: "x",
},
{
data: "dis",
field: "x",
},
],
},
}; };
export let linearYScale: LinearScale = { export let linearYScale: LinearScale = {
name: "yscale", name: "yscale",
type: "linear", type: "linear",
range: "height", range: "height",
zero: false, zero: false,
domain: { domain: { data: "domain", field: "y" },
fields: [
{
data: "con",
field: "y",
},
{
data: "dis",
field: "y",
},
],
},
}; };
export let logXScale: LogScale = { export let logXScale: LogScale = {
@ -65,18 +43,7 @@ export let logXScale: LogScale = {
base: 10, base: 10,
nice: false, nice: false,
clamp: true, clamp: true,
domain: { domain: { data: "domain", field: "x" },
fields: [
{
data: "con",
field: "x",
},
{
data: "dis",
field: "x",
},
],
},
}; };
export let expYScale: PowScale = { export let expYScale: PowScale = {
@ -86,24 +53,13 @@ export let expYScale: PowScale = {
range: "height", range: "height",
zero: false, zero: false,
nice: false, nice: false,
domain: { domain: { data: "domain", field: "y" },
fields: [
{
data: "con",
field: "y",
},
{
data: "dis",
field: "y",
},
],
},
}; };
export function buildVegaSpec( export function buildVegaSpec(
specOptions: DistributionChartSpecOptions specOptions: DistributionChartSpecOptions
): VisualizationSpec { ): VisualizationSpec {
let { const {
format = ".9~s", format = ".9~s",
color = "#739ECC", color = "#739ECC",
title, title,
@ -130,10 +86,10 @@ export function buildVegaSpec(
padding: 5, padding: 5,
data: [ data: [
{ {
name: "con", name: "data",
}, },
{ {
name: "dis", name: "domain",
}, },
], ],
signals: [], signals: [],
@ -144,12 +100,10 @@ export function buildVegaSpec(
name: "color", name: "color",
type: "ordinal", type: "ordinal",
domain: { domain: {
fields: [ data: "data",
{ data: "con", field: "name" }, field: "name",
{ data: "dis", field: "name" },
],
}, },
range: { scheme: "category20b" }, range: { scheme: "category10" },
}, },
], ],
axes: [ axes: [
@ -167,109 +121,132 @@ export function buildVegaSpec(
], ],
marks: [ marks: [
{ {
name: "group", name: "all_distributions",
type: "group", type: "group",
from: { from: {
facet: { facet: {
name: "faceted_path_main", name: "distribution_facet",
data: "con", data: "data",
groupby: ["name"], groupby: ["name"],
}, },
}, },
marks: [ marks: [
{ {
name: "distribution_charts", name: "continuous_distribution",
type: "area", type: "group",
from: { from: {
data: "faceted_path_main", facet: {
}, name: "continuous_facet",
encode: { data: "distribution_facet",
update: { field: "continuous",
interpolate: { value: "linear" },
x: {
scale: "xscale",
field: "x",
},
y: {
scale: "yscale",
field: "y",
},
y2: {
scale: "yscale",
value: 0,
},
fill: {
field: "name",
scale: "color",
},
fillOpacity: {
value: 1,
},
}, },
}, },
encode: {
update: {},
},
marks: [
{
name: "continuous_area",
type: "area",
from: {
data: "continuous_facet",
},
encode: {
update: {
interpolate: { value: "linear" },
x: {
scale: "xscale",
field: "x",
},
y: {
scale: "yscale",
field: "y",
},
fill: {
scale: "color",
field: { parent: "name" },
},
y2: {
scale: "yscale",
value: 0,
},
fillOpacity: {
value: 1,
},
},
},
},
],
},
{
name: "discrete_distribution",
type: "group",
from: {
facet: {
name: "discrete_facet",
data: "distribution_facet",
field: "discrete",
},
},
marks: [
{
type: "rect",
from: {
data: "discrete_facet",
},
encode: {
enter: {
width: {
value: 1,
},
},
update: {
x: {
scale: "xscale",
field: "x",
},
y: {
scale: "yscale",
field: "y",
},
y2: {
scale: "yscale",
value: 0,
},
},
},
},
{
type: "symbol",
from: {
data: "discrete_facet",
},
encode: {
enter: {
shape: {
value: "circle",
},
size: [{ value: 100 }],
tooltip: {
signal: "datum.y",
},
},
update: {
x: {
scale: "xscale",
field: "x",
},
y: {
scale: "yscale",
field: "y",
},
},
},
},
],
}, },
], ],
}, },
{
type: "rect",
from: {
data: "dis",
},
encode: {
enter: {
width: {
value: 1,
},
},
update: {
x: {
scale: "xscale",
field: "x",
},
y: {
scale: "yscale",
field: "y",
},
y2: {
scale: "yscale",
value: 0,
},
fill: {
value: "#2f65a7",
},
},
},
},
{
type: "symbol",
from: {
data: "dis",
},
encode: {
enter: {
shape: {
value: "circle",
},
size: [{ value: 100 }],
tooltip: {
signal: "datum.y",
},
},
update: {
x: {
scale: "xscale",
field: "x",
},
y: {
scale: "yscale",
field: "y",
},
fill: {
value: "#1e4577",
},
},
},
},
], ],
}; };
if (title) { if (title) {