From 9cbeee04515c7f290a80ff39981c51d4ffbfef59 Mon Sep 17 00:00:00 2001 From: Sam Nolan Date: Tue, 12 Jul 2022 17:09:24 +1000 Subject: [PATCH] Add multiple plotting --- .../src/components/DistributionChart.tsx | 76 +++++++++++++++--- .../src/components/SquiggleItem.tsx | 17 +++- .../src/lib/distributionSpecBuilder.ts | 79 +++++++++++++------ 3 files changed, 135 insertions(+), 37 deletions(-) diff --git a/packages/components/src/components/DistributionChart.tsx b/packages/components/src/components/DistributionChart.tsx index af644d29..0b9535b2 100644 --- a/packages/components/src/components/DistributionChart.tsx +++ b/packages/components/src/components/DistributionChart.tsx @@ -4,6 +4,7 @@ import { result, distributionError, distributionErrorToString, + squiggleExpression, } from "@quri/squiggle-lang"; import { Vega } from "react-vega"; import { ErrorAlert } from "./Alert"; @@ -23,16 +24,58 @@ export type DistributionPlottingSettings = { showControls: boolean; } & DistributionChartSpecOptions; +export type Plot = { + distributions: Distribution[]; +}; + export type DistributionChartProps = { - distribution: Distribution; + plot: Plot; width?: number; height: number; actions?: boolean; } & DistributionPlottingSettings; +export function defaultPlot(distribution: Distribution): Plot { + return { distributions: [distribution] }; +} +export function makePlot(expression: { + [key: string]: squiggleExpression; +}): Plot | void { + if (expression["distributions"].tag === "array") { + let distributions: Distribution[] = expression["distributions"].value + .map((x) => { + if (x.tag === "distribution") { + return x.value; + } + }) + .filter((x): x is Distribution => x !== undefined); + return { distributions }; + } +} +function all(arr: boolean[]): boolean { + return arr.reduce((x, y) => x && y, true); +} + +function flattenResult(x: result[]): result { + if (x.length === 0) { + return { tag: "Ok", value: [] }; + } else { + if (x[0].tag === "Error") { + return x[0]; + } else { + let rest = flattenResult(x.splice(1)); + if (rest.tag === "Error") { + return rest; + } else { + return { tag: "Ok", value: [x[0].value].concat(rest.value) }; + } + } + } +} + export const DistributionChart: React.FC = (props) => { const { - distribution, + plot, height, showSummary, width, @@ -47,19 +90,23 @@ export const DistributionChart: React.FC = (props) => { React.useEffect(() => setLogX(logX), [logX]); React.useEffect(() => setExpY(expY), [expY]); - const shape = distribution.pointSet(); const [sized] = useSize((size) => { - if (shape.tag === "Error") { + let shapes = flattenResult(plot.distributions.map((x) => x.pointSet())); + if (shapes.tag === "Error") { return ( - {distributionErrorToString(shape.value)} + {distributionErrorToString(shapes.value)} ); } - const massBelow0 = - shape.value.continuous.some((x) => x.x <= 0) || - shape.value.discrete.some((x) => x.x <= 0); + const massBelow0 = all( + shapes.value.map( + (shape) => + shape.continuous.some((x) => x.x <= 0) || + shape.discrete.some((x) => x.x <= 0) + ) + ); const spec = buildVegaSpec(props); let widthProp = width ? width : size.width; @@ -69,13 +116,20 @@ export const DistributionChart: React.FC = (props) => { ); widthProp = 20; } + let continuousPoints = shapes.value.flatMap((shape, i) => + shape.continuous.map((point) => ({ ...point, name: i + 1 })) + ); + let discretePoints = shapes.value.flatMap((shape, i) => + shape.discrete.map((point) => ({ ...point, name: i + 1 })) + ); + console.log(continuousPoints); return (
{!(isLogX && massBelow0) ? ( = (props) => { )}
- {showSummary && } + {showSummary && plot.distributions.length == 1 && ( + + )}
{showControls && (
diff --git a/packages/components/src/components/SquiggleItem.tsx b/packages/components/src/components/SquiggleItem.tsx index 48a8a0fb..09b9fa27 100644 --- a/packages/components/src/components/SquiggleItem.tsx +++ b/packages/components/src/components/SquiggleItem.tsx @@ -8,6 +8,8 @@ import { NumberShower } from "./NumberShower"; import { DistributionChart, DistributionPlottingSettings, + makePlot, + defaultPlot, } from "./DistributionChart"; import { FunctionChart, FunctionChartSettings } from "./FunctionChart"; @@ -102,7 +104,7 @@ export const SquiggleItem: React.FC = ({
{expression.value.toString()}
) : null} = ({ ); case "record": + let plot = makePlot(expression.value); + if (plot) { + return ( + + ); + } return (
@@ -246,7 +259,7 @@ export const SquiggleItem: React.FC = ({
{Object.entries(expression.value) - .filter(([key, r]) => key !== "Math") + .filter(([key, _]) => key !== "Math") .map(([key, r]) => (
diff --git a/packages/components/src/lib/distributionSpecBuilder.ts b/packages/components/src/lib/distributionSpecBuilder.ts index 4286dbdb..515a1b9f 100644 --- a/packages/components/src/lib/distributionSpecBuilder.ts +++ b/packages/components/src/lib/distributionSpecBuilder.ts @@ -137,7 +137,21 @@ export function buildVegaSpec( }, ], signals: [], - scales: [xScale, expY ? expYScale : linearYScale], + scales: [ + xScale, + expY ? expYScale : linearYScale, + { + name: "color", + type: "ordinal", + domain: { + fields: [ + { data: "con", field: "name" }, + { data: "dis", field: "name" }, + ], + }, + range: { scheme: "category20b" }, + }, + ], axes: [ { orient: "bottom", @@ -153,33 +167,48 @@ export function buildVegaSpec( ], marks: [ { - type: "area", + name: "group", + type: "group", from: { - data: "con", - }, - encode: { - update: { - interpolate: { value: "linear" }, - x: { - scale: "xscale", - field: "x", - }, - y: { - scale: "yscale", - field: "y", - }, - y2: { - scale: "yscale", - value: 0, - }, - fill: { - value: color, - }, - fillOpacity: { - value: 1, - }, + facet: { + name: "faceted_path_main", + data: "con", + groupby: ["name"], }, }, + marks: [ + { + name: "distribution_charts", + type: "area", + from: { + data: "faceted_path_main", + }, + encode: { + update: { + interpolate: { value: "linear" }, + x: { + scale: "xscale", + field: "x", + }, + y: { + scale: "yscale", + field: "y", + }, + y2: { + scale: "yscale", + value: 0, + }, + fill: { + field: "name", + scale: "color", + }, + fillOpacity: { + value: 1, + }, + }, + }, + }, + ], }, { type: "rect",