squiggle/packages/components/src/SquiggleChart.tsx

347 lines
10 KiB
TypeScript
Raw Normal View History

2022-03-23 00:38:01 +00:00
import * as React from "react";
import _ from "lodash";
import type { Spec } from "vega";
import { run } from "@quri/squiggle-lang";
import type {
DistPlus,
SamplingInputs,
exportEnv,
exportDistribution,
} from "@quri/squiggle-lang";
import { createClassFromSpec } from "react-vega";
import * as chartSpecification from "./spec-distributions.json";
import * as percentilesSpec from "./spec-percentiles.json";
import { NumberShower } from "./NumberShower";
2022-04-07 11:28:33 +00:00
import styled from "styled-components";
/** Vega chart component used to plot distribution results (built from spec-distributions.json). */
const SquiggleVegaChart = createClassFromSpec({
  spec: chartSpecification as Spec,
});

/** Vega chart component used to plot percentiles of function results (built from spec-percentiles.json). */
const SquigglePercentilesChart = createClassFromSpec({
  spec: percentilesSpec as Spec,
});
/** Props accepted by the SquiggleChart component. */
export interface SquiggleChartProps {
  /** The squiggle program to evaluate and render. */
  squiggleString?: string;
  /** If the output requires Monte Carlo sampling, the number of samples to draw. */
  sampleCount?: number;
  /** The number of points returned to draw the distribution. */
  outputXYPoints?: number;
  /**
   * Kernel width forwarded to squiggle-lang's sampling inputs.
   * NOTE(review): presumably the smoothing bandwidth used when converting
   * samples into a density; confirm against squiggle-lang.
   */
  kernelWidth?: number;
  /** Point-set distribution length forwarded to squiggle-lang's sampling inputs. */
  pointDistLength?: number;
  /** If the result is a function, the x value where the function chart starts. */
  diagramStart?: number;
  /** If the result is a function, the x value where the function chart stops. */
  diagramStop?: number;
  /** If the result is a function, how many points along the function it samples. */
  diagramCount?: number;
  /** Variables declared before this expression. */
  environment?: exportEnv;
  /** Called with the environment produced by a successful run. */
  onEnvChange?(env: exportEnv): void;
  /** Width of the rendered distribution chart (a number passed to the Vega chart, not a CSS string). */
  width?: number;
  /** Height of the rendered distribution chart (a number passed to the Vega chart). */
  height?: number;
}
/**
 * Bordered, tinted container for error messages.
 * Named ErrorBox (not Error) so it does not shadow the global Error constructor.
 */
const ErrorBox = styled.div`
  border: 1px solid #792e2e;
  background: #eee2e2;
  padding: 0.4em 0.8em;
`;

/**
 * Renders an error panel: a heading followed by arbitrary content.
 *
 * @param heading - Title shown above the content; defaults to "Error".
 * @param children - Error details rendered below the heading.
 */
const ShowError: React.FC<{ heading?: string; children: React.ReactNode }> = ({
  heading = "Error",
  children,
}) => {
  return (
    <ErrorBox>
      <h3>{heading}</h3>
      {children}
    </ErrorBox>
  );
};
/** Percentile levels sampled at each x when charting a function result. */
const PERCENTILE_LEVELS = [
  0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99,
];

/** Formats a cumulative probability in [0, 1] as a percentage label, e.g. 0.5 -> "50.00%". */
function formatCdfLabel(probability: number): string {
  return (probability * 100).toFixed(2) + "%";
}

/** Minimal view of a squiggle-lang xy shape: parallel arrays of x positions and y values. */
interface ChartXYShape {
  xs: number[];
  ys: number[];
}

/**
 * Builds Vega rows for a purely continuous or purely discrete shape. Each row
 * carries x, y and the running CDF (cumulative sum of ys divided by their
 * total) as a percentage label. An empty shape yields no rows.
 */
function cdfRows(xyShape: ChartXYShape) {
  // _.sum returns 0 for [], whereas a bare reduce without a seed would throw.
  const totalY = _.sum(xyShape.ys);
  let running = 0;
  return xyShape.xs.map((x, i) => {
    const y = xyShape.ys[i];
    running += y;
    return { cdf: formatCdfLabel(running / totalY), x: x, y: y };
  });
}

/** A point of a mixed distribution, tagged with the component it came from. */
interface MixedChartPoint {
  x: number;
  y: number;
  type: "discrete" | "continuous";
}

/**
 * Builds Vega rows for a mixed distribution.
 *
 * Discrete and continuous points are merged in x order (discrete first on
 * ties, since the sort is stable). Discrete ys count as probability mass
 * directly; continuous ys are rescaled so the continuous part carries
 * `1 - (sum of discrete ys)`. The rows are split back into the `con` and
 * `dis` series expected by the distribution chart.
 */
function mixedCdfRows(discrete: ChartXYShape, continuous: ChartXYShape) {
  const totalContinuous = 1 - _.sum(discrete.ys);
  const totalContinuousY = _.sum(continuous.ys);
  const points = discrete.xs
    .map(
      (x, i): MixedChartPoint => ({ x: x, y: discrete.ys[i], type: "discrete" })
    )
    .concat(
      continuous.xs.map(
        (x, i): MixedChartPoint => ({
          x: x,
          y: continuous.ys[i],
          type: "continuous",
        })
      )
    );
  let running = 0;
  const rows = _.sortBy(points, "x").map((point) => {
    running +=
      point.type === "discrete"
        ? point.y
        : (point.y / totalContinuousY) * totalContinuous;
    return { ...point, cdf: formatCdfLabel(running) };
  });
  return {
    con: rows.filter((row) => row.type === "continuous"),
    dis: rows.filter((row) => row.type === "discrete"),
  };
}

/**
 * Renders a distribution's point-set shape with the Vega distribution chart.
 * Continuous shapes fill the `con` dataset, discrete shapes `dis`, and mixed
 * shapes both. `width`/`height` are applied to every shape kind. Returns null
 * for any other shape tag.
 */
function renderDistPlus(distPlus: DistPlus, width: number, height: number) {
  const shape = distPlus.pointSetDist;
  let data: Record<string, unknown[]>;
  if (shape.tag === "Continuous") {
    data = { con: cdfRows(shape.value.xyShape) };
  } else if (shape.tag === "Discrete") {
    data = { dis: cdfRows(shape.value.xyShape) };
  } else if (shape.tag === "Mixed") {
    data = mixedCdfRows(
      shape.value.discrete.xyShape,
      shape.value.continuous.xyShape
    );
  } else {
    return null;
  }
  return (
    <SquiggleVegaChart
      width={width}
      height={height}
      data={data}
      actions={false}
    />
  );
}

/**
 * Evaluates a squiggle program and renders each of its exports.
 * - Numbers are shown with NumberShower.
 * - Distributions are shown with the Vega distribution chart.
 * - Functions are shown with a percentile chart sampled over
 *   [diagramStart, diagramStop) in steps of (stop - start) / diagramCount.
 *
 * A failed run is rendered as a "Parse Error" panel.
 */
export const SquiggleChart: React.FC<SquiggleChartProps> = ({
  squiggleString = "",
  sampleCount = 1000,
  outputXYPoints = 1000,
  kernelWidth,
  pointDistLength = 1000,
  diagramStart = 0,
  diagramStop = 10,
  diagramCount = 20,
  environment = [],
  onEnvChange = () => {},
  width = 500,
  height = 60,
}: SquiggleChartProps) => {
  const samplingInputs: SamplingInputs = {
    sampleCount,
    outputXYPoints,
    kernelWidth,
    pointDistLength,
  };

  const result = run(squiggleString, samplingInputs, environment);

  if (result.tag === "Ok") {
    const { environment: newEnvironment, exports } = result.value;
    // NOTE(review): this notifies the parent during render, on every render.
    // Moving it into an effect needs the run result memoised first, or a
    // parent that stores the env in state could loop; left as-is for now.
    onEnvChange(newEnvironment);

    const renderExport = (chartResult: exportDistribution) => {
      if (chartResult.NAME === "Float") {
        return <NumberShower precision={3} number={chartResult.VAL} />;
      } else if (chartResult.NAME === "DistPlus") {
        return renderDistPlus(chartResult.VAL, width, height);
      } else if (chartResult.NAME === "Function") {
        // Capture the narrowed function so the closure below keeps its type.
        const evaluate = chartResult.VAL;
        const step = (diagramStop - diagramStart) / diagramCount;
        const rows = _.range(diagramStart, diagramStop, step).map((x) => {
          const evaluated = evaluate(x);
          if (evaluated.tag === "Ok") {
            const p = getPercentiles(PERCENTILE_LEVELS, evaluated.value);
            return {
              x: x,
              p1: p[0],
              p5: p[1],
              p10: p[2],
              p20: p[3],
              p30: p[4],
              p40: p[5],
              p50: p[6],
              p60: p[7],
              p70: p[8],
              p80: p[9],
              p90: p[10],
              p95: p[11],
              p99: p[12],
            };
          }
          // Points where evaluation fails are dropped from the chart.
          return null;
        });
        return (
          <SquigglePercentilesChart
            data={{ facet: rows.filter((row) => row !== null) }}
            actions={false}
          />
        );
      }
      return null;
    };

    return (
      <>
        {exports.map((chartResult: exportDistribution, index: number) => (
          // Exports are positional with no stable id, so the index is the key.
          <React.Fragment key={index}>{renderExport(chartResult)}</React.Fragment>
        ))}
      </>
    );
  } else if (result.tag === "Error") {
    // At this point, we came across an error; show it to the user.
    return <ShowError heading={"Parse Error"}>{result.value}</ShowError>;
  }
  return <p>{"Invalid Response"}</p>;
};
/**
 * Walks (x, mass) pairs in the given order, accumulating mass, and returns,
 * for each level, the first x at which the running total strictly exceeds it.
 * Levels that are never exceeded fall back to the largest x seen, or to
 * `undefined` when `points` is empty.
 */
function percentileBoundsFromMasses(
  levels: number[],
  points: Array<[number, number]>
): Array<number | undefined> {
  let maxX: number | undefined = undefined;
  for (const [x] of points) {
    if (maxX === undefined || x > maxX) {
      maxX = x;
    }
  }
  const bounds: Array<number | undefined> = levels.map(() => maxX);
  const found = levels.map(() => false);
  let cumulative = 0;
  for (const [x, mass] of points) {
    cumulative += mass;
    levels.forEach((level, i) => {
      if (!found[i] && cumulative > level) {
        bounds[i] = x;
        found[i] = true;
      }
    });
  }
  return bounds;
}

/**
 * Approximates the given percentiles of a point-set distribution.
 *
 * For each requested level p (a probability in [0, 1]), the result holds the
 * first x, in x order, at which the cumulative probability mass strictly
 * exceeds p. Levels that are never exceeded fall back to the largest x.
 *
 * Mass per point:
 * - Discrete: the raw y values, assumed to already be probabilities.
 * - Continuous: y divided by the sum of all ys.
 * - Mixed: discrete ys as-is; continuous ys rescaled so the continuous part
 *   carries `1 - (sum of discrete ys)`.
 *
 * Discrete and continuous xs are used in their given order (assumed sorted).
 * Mixed points are merged and stably sorted by x.
 *
 * @param percentiles - Probability levels to locate, e.g. [0.05, 0.5, 0.95].
 * @param t - Distribution whose `pointSetDist` is inspected.
 * @returns One x value per entry of `percentiles`, in the same order; entries
 *   are `undefined` if the distribution has no points.
 */
function getPercentiles(percentiles: number[], t: DistPlus) {
  const dist = t.pointSetDist;
  switch (dist.tag) {
    case "Discrete": {
      const { xs, ys } = dist.value.xyShape;
      return percentileBoundsFromMasses(
        percentiles,
        xs.map((x, i): [number, number] => [x, ys[i]])
      );
    }
    case "Continuous": {
      const { xs, ys } = dist.value.xyShape;
      const totalY = ys.reduce((a, b) => a + b, 0);
      return percentileBoundsFromMasses(
        percentiles,
        xs.map((x, i): [number, number] => [x, ys[i] / totalY])
      );
    }
    case "Mixed": {
      const discrete = dist.value.discrete.xyShape;
      const continuous = dist.value.continuous.xyShape;
      const totalContinuous = 1 - discrete.ys.reduce((a, b) => a + b, 0);
      const totalContinuousY = continuous.ys.reduce((a, b) => a + b, 0);
      const points = discrete.xs
        .map((x, i): [number, number] => [x, discrete.ys[i]])
        .concat(
          continuous.xs.map((x, i): [number, number] => [
            x,
            (continuous.ys[i] / totalContinuousY) * totalContinuous,
          ])
        );
      // Stable sort: on equal x, discrete points stay ahead of continuous ones.
      points.sort((a, b) => a[0] - b[0]);
      // Previously this branch returned cumulative probabilities instead of x
      // values; the shared helper returns x positions like the other branches.
      return percentileBoundsFromMasses(percentiles, points);
    }
  }
}