Add multiple plotting

This commit is contained in:
Sam Nolan 2022-07-12 17:09:24 +10:00
parent af4423cc2e
commit 9cbeee0451
3 changed files with 135 additions and 37 deletions

View File

@ -4,6 +4,7 @@ import {
result, result,
distributionError, distributionError,
distributionErrorToString, distributionErrorToString,
squiggleExpression,
} from "@quri/squiggle-lang"; } from "@quri/squiggle-lang";
import { Vega } from "react-vega"; import { Vega } from "react-vega";
import { ErrorAlert } from "./Alert"; import { ErrorAlert } from "./Alert";
@ -23,16 +24,58 @@ export type DistributionPlottingSettings = {
showControls: boolean; showControls: boolean;
} & DistributionChartSpecOptions; } & DistributionChartSpecOptions;
export type Plot = {
distributions: Distribution[];
};
export type DistributionChartProps = { export type DistributionChartProps = {
distribution: Distribution; plot: Plot;
width?: number; width?: number;
height: number; height: number;
actions?: boolean; actions?: boolean;
} & DistributionPlottingSettings; } & DistributionPlottingSettings;
export function defaultPlot(distribution: Distribution): Plot {
return { distributions: [distribution] };
}
export function makePlot(expression: {
[key: string]: squiggleExpression;
}): Plot | void {
if (expression["distributions"].tag === "array") {
let distributions: Distribution[] = expression["distributions"].value
.map((x) => {
if (x.tag === "distribution") {
return x.value;
}
})
.filter((x): x is Distribution => x !== undefined);
return { distributions };
}
}
function all(arr: boolean[]): boolean {
return arr.reduce((x, y) => x && y, true);
}
function flattenResult<a, b>(x: result<a, b>[]): result<a[], b> {
if (x.length === 0) {
return { tag: "Ok", value: [] };
} else {
if (x[0].tag === "Error") {
return x[0];
} else {
let rest = flattenResult(x.splice(1));
if (rest.tag === "Error") {
return rest;
} else {
return { tag: "Ok", value: [x[0].value].concat(rest.value) };
}
}
}
}
export const DistributionChart: React.FC<DistributionChartProps> = (props) => { export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
const { const {
distribution, plot,
height, height,
showSummary, showSummary,
width, width,
@ -47,19 +90,23 @@ export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
React.useEffect(() => setLogX(logX), [logX]); React.useEffect(() => setLogX(logX), [logX]);
React.useEffect(() => setExpY(expY), [expY]); React.useEffect(() => setExpY(expY), [expY]);
const shape = distribution.pointSet();
const [sized] = useSize((size) => { const [sized] = useSize((size) => {
if (shape.tag === "Error") { let shapes = flattenResult(plot.distributions.map((x) => x.pointSet()));
if (shapes.tag === "Error") {
return ( return (
<ErrorAlert heading="Distribution Error"> <ErrorAlert heading="Distribution Error">
{distributionErrorToString(shape.value)} {distributionErrorToString(shapes.value)}
</ErrorAlert> </ErrorAlert>
); );
} }
const massBelow0 = const massBelow0 = all(
shape.value.continuous.some((x) => x.x <= 0) || shapes.value.map(
shape.value.discrete.some((x) => x.x <= 0); (shape) =>
shape.continuous.some((x) => x.x <= 0) ||
shape.discrete.some((x) => x.x <= 0)
)
);
const spec = buildVegaSpec(props); const spec = buildVegaSpec(props);
let widthProp = width ? width : size.width; let widthProp = width ? width : size.width;
@ -69,13 +116,20 @@ export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
); );
widthProp = 20; widthProp = 20;
} }
let continuousPoints = shapes.value.flatMap((shape, i) =>
shape.continuous.map((point) => ({ ...point, name: i + 1 }))
);
let discretePoints = shapes.value.flatMap((shape, i) =>
shape.discrete.map((point) => ({ ...point, name: i + 1 }))
);
console.log(continuousPoints);
return ( return (
<div style={{ width: widthProp }}> <div style={{ width: widthProp }}>
{!(isLogX && massBelow0) ? ( {!(isLogX && massBelow0) ? (
<Vega <Vega
spec={spec} spec={spec}
data={{ con: shape.value.continuous, dis: shape.value.discrete }} data={{ con: continuousPoints, dis: discretePoints }}
width={widthProp - 10} width={widthProp - 10}
height={height} height={height}
actions={actions} actions={actions}
@ -86,7 +140,9 @@ export const DistributionChart: React.FC<DistributionChartProps> = (props) => {
</ErrorAlert> </ErrorAlert>
)} )}
<div className="flex justify-center"> <div className="flex justify-center">
{showSummary && <SummaryTable distribution={distribution} />} {showSummary && plot.distributions.length == 1 && (
<SummaryTable distribution={plot.distributions[0]} />
)}
</div> </div>
{showControls && ( {showControls && (
<div> <div>

View File

@ -8,6 +8,8 @@ import { NumberShower } from "./NumberShower";
import { import {
DistributionChart, DistributionChart,
DistributionPlottingSettings, DistributionPlottingSettings,
makePlot,
defaultPlot,
} from "./DistributionChart"; } from "./DistributionChart";
import { FunctionChart, FunctionChartSettings } from "./FunctionChart"; import { FunctionChart, FunctionChartSettings } from "./FunctionChart";
@ -102,7 +104,7 @@ export const SquiggleItem: React.FC<SquiggleItemProps> = ({
<div>{expression.value.toString()}</div> <div>{expression.value.toString()}</div>
) : null} ) : null}
<DistributionChart <DistributionChart
distribution={expression.value} plot={defaultPlot(expression.value)}
{...distributionPlotSettings} {...distributionPlotSettings}
height={height} height={height}
width={width} width={width}
@ -164,6 +166,17 @@ export const SquiggleItem: React.FC<SquiggleItemProps> = ({
</VariableBox> </VariableBox>
); );
case "record": case "record":
let plot = makePlot(expression.value);
if (plot) {
return (
<DistributionChart
plot={plot}
{...distributionPlotSettings}
height={height}
width={width}
/>
);
}
return ( return (
<VariableBox heading="Record" showTypes={showTypes}> <VariableBox heading="Record" showTypes={showTypes}>
<div className="space-y-3"> <div className="space-y-3">
@ -246,7 +259,7 @@ export const SquiggleItem: React.FC<SquiggleItemProps> = ({
<VariableBox heading="Module" showTypes={showTypes}> <VariableBox heading="Module" showTypes={showTypes}>
<div className="space-y-3"> <div className="space-y-3">
{Object.entries(expression.value) {Object.entries(expression.value)
.filter(([key, r]) => key !== "Math") .filter(([key, _]) => key !== "Math")
.map(([key, r]) => ( .map(([key, r]) => (
<div key={key} className="flex space-x-2"> <div key={key} className="flex space-x-2">
<div className="flex-none"> <div className="flex-none">

View File

@ -137,7 +137,21 @@ export function buildVegaSpec(
}, },
], ],
signals: [], signals: [],
scales: [xScale, expY ? expYScale : linearYScale], scales: [
xScale,
expY ? expYScale : linearYScale,
{
name: "color",
type: "ordinal",
domain: {
fields: [
{ data: "con", field: "name" },
{ data: "dis", field: "name" },
],
},
range: { scheme: "category20b" },
},
],
axes: [ axes: [
{ {
orient: "bottom", orient: "bottom",
@ -153,9 +167,21 @@ export function buildVegaSpec(
], ],
marks: [ marks: [
{ {
name: "group",
type: "group",
from: {
facet: {
name: "faceted_path_main",
data: "con",
groupby: ["name"],
},
},
marks: [
{
name: "distribution_charts",
type: "area", type: "area",
from: { from: {
data: "con", data: "faceted_path_main",
}, },
encode: { encode: {
update: { update: {
@ -173,7 +199,8 @@ export function buildVegaSpec(
value: 0, value: 0,
}, },
fill: { fill: {
value: color, field: "name",
scale: "color",
}, },
fillOpacity: { fillOpacity: {
value: 1, value: 1,
@ -181,6 +208,8 @@ export function buildVegaSpec(
}, },
}, },
}, },
],
},
{ {
type: "rect", type: "rect",
from: { from: {