Merge pull request #504 from quantified-uncertainty/function-charts

Function charting
This commit is contained in:
Ozzie Gooen 2022-05-10 18:57:18 -04:00 committed by GitHub
commit 3cca106079
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 199 additions and 115 deletions

View File

@ -1,18 +1,24 @@
import * as React from "react";
import _ from "lodash";
import type { Spec } from "vega";
import type { Distribution, errorValue, result } from "@quri/squiggle-lang";
import {
Distribution,
result,
lambdaValue,
environment,
runForeign,
errorValueToString,
} from "@quri/squiggle-lang";
import { createClassFromSpec } from "react-vega";
import * as percentilesSpec from "../vega-specs/spec-percentiles.json";
import { DistributionChart } from "./DistributionChart";
import { NumberShower } from "./NumberShower";
import { ErrorBox } from "./ErrorBox";
let SquigglePercentilesChart = createClassFromSpec({
spec: percentilesSpec as Spec,
});
type distPlusFn = (a: number) => result<Distribution, errorValue>;
const _rangeByCount = (start: number, stop: number, count: number) => {
const step = (stop - start) / (count - 1);
const items = _.range(start, stop, step);
@ -27,38 +33,36 @@ function unwrap<a, b>(x: result<a, b>): a {
throw Error("FAILURE TO UNWRAP");
}
}
export type FunctionChartSettings = {
start: number;
stop: number;
count: number;
};
function mapFilter<a, b>(xs: a[], f: (x: a) => b | undefined): b[] {
let initial: b[] = [];
return xs.reduce((previous, current) => {
let value: b | undefined = f(current);
if (value !== undefined) {
return previous.concat([value]);
} else {
return previous;
}
}, initial);
interface FunctionChartProps {
fn: lambdaValue;
chartSettings: FunctionChartSettings;
environment: environment;
}
export const FunctionChart: React.FC<{
distPlusFn: distPlusFn;
diagramStart: number;
diagramStop: number;
diagramCount: number;
}> = ({ distPlusFn, diagramStart, diagramStop, diagramCount }) => {
export const FunctionChart: React.FC<FunctionChartProps> = ({
fn,
chartSettings,
environment,
}: FunctionChartProps) => {
let [mouseOverlay, setMouseOverlay] = React.useState(0);
function handleHover(...args) {
setMouseOverlay(args[1]);
function handleHover(_name: string, value: unknown) {
setMouseOverlay(value as number);
}
function handleOut() {
setMouseOverlay(NaN);
}
const signalListeners = { mousemove: handleHover, mouseout: handleOut };
let mouseItem = distPlusFn(mouseOverlay);
let mouseItem = runForeign(fn, [mouseOverlay], environment);
let showChart =
mouseItem.tag === "Ok" ? (
mouseItem.tag === "Ok" && mouseItem.value.tag == "distribution" ? (
<DistributionChart
distribution={mouseItem.value}
distribution={mouseItem.value.value}
width={400}
height={140}
showSummary={false}
@ -66,13 +70,49 @@ export const FunctionChart: React.FC<{
) : (
<></>
);
let data1 = _rangeByCount(diagramStart, diagramStop, diagramCount);
let valueData = mapFilter(data1, (x) => {
let result = distPlusFn(x);
let data1 = _rangeByCount(
chartSettings.start,
chartSettings.stop,
chartSettings.count
);
type point = { x: number; value: result<Distribution, string> };
let valueData: point[] = data1.map((x) => {
let result = runForeign(fn, [x], environment);
if (result.tag === "Ok") {
return { x: x, value: result.value };
if (result.value.tag == "distribution") {
return { x, value: { tag: "Ok", value: result.value.value } };
} else {
return {
x,
value: {
tag: "Error",
value:
"Cannot currently render functions that don't return distributions",
},
};
}
} else {
return {
x,
value: { tag: "Error", value: errorValueToString(result.value) },
};
}
}).map(({ x, value }) => {
});
let initialPartition: [
{ x: number; value: Distribution }[],
{ x: number; value: string }[]
] = [[], []];
let [functionImage, errors] = valueData.reduce((acc, current) => {
if (current.value.tag === "Ok") {
acc[0].push({ x: current.x, value: current.value.value });
} else {
acc[1].push({ x: current.x, value: current.value.value });
}
return acc;
}, initialPartition);
let percentiles = functionImage.map(({ x, value }) => {
return {
x: x,
p1: unwrap(value.inv(0.01)),
@ -91,24 +131,25 @@ export const FunctionChart: React.FC<{
};
});
let errorData = mapFilter(data1, (x) => {
let result = distPlusFn(x);
if (result.tag === "Error") {
return { x: x, error: result.value };
}
});
let error2 = _.groupBy(errorData, (x) => x.error);
let groupedErrors = _.groupBy(errors, (x) => x.value);
return (
<>
<SquigglePercentilesChart
data={{ facet: valueData }}
data={{ facet: percentiles }}
actions={false}
signalListeners={signalListeners}
/>
{showChart}
{_.keysIn(error2).map((k) => (
<ErrorBox heading={k}>
{`Values: [${error2[k].map((r) => r.x.toFixed(2)).join(",")}]`}
{_.entries(groupedErrors).map(([errorName, errorPoints]) => (
<ErrorBox heading={errorName}>
Values:{" "}
{errorPoints
.map((r) => <NumberShower number={r.x} />)
.reduce((a, b) => (
<>
{a}, {b}
</>
))}
</ErrorBox>
))}
</>

View File

@ -6,14 +6,16 @@ import {
errorValueToString,
squiggleExpression,
bindings,
samplingParams,
environment,
jsImports,
defaultImports,
defaultBindings,
defaultEnvironment,
} from "@quri/squiggle-lang";
import { NumberShower } from "./NumberShower";
import { DistributionChart } from "./DistributionChart";
import { ErrorBox } from "./ErrorBox";
import { FunctionChart, FunctionChartSettings } from "./FunctionChart";
const variableBox = {
Component: styled.div`
@ -36,7 +38,7 @@ const variableBox = {
interface VariableBoxProps {
heading: string;
children: React.ReactNode;
showTypes?: boolean;
showTypes: boolean;
}
export const VariableBox: React.FC<VariableBoxProps> = ({
@ -68,9 +70,13 @@ export interface SquiggleItemProps {
/** Whether to show a summary of statistics for distributions */
showSummary: boolean;
/** Whether to show type information */
showTypes?: boolean;
showTypes: boolean;
/** Whether to show users graph controls (scale etc) */
showControls?: boolean;
showControls: boolean;
/** Settings for displaying functions */
chartSettings: FunctionChartSettings;
/** Environment for further function executions */
environment: environment;
}
const SquiggleItem: React.FC<SquiggleItemProps> = ({
@ -80,6 +86,8 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
showSummary,
showTypes = false,
showControls = false,
chartSettings,
environment,
}: SquiggleItemProps) => {
switch (expression.tag) {
case "number":
@ -147,6 +155,8 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
height={50}
showTypes={showTypes}
showControls={showControls}
chartSettings={chartSettings}
environment={environment}
showSummary={showSummary}
/>
))}
@ -165,6 +175,8 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
showTypes={showTypes}
showSummary={showSummary}
showControls={showControls}
chartSettings={chartSettings}
environment={environment}
/>
</>
))}
@ -178,9 +190,11 @@ const SquiggleItem: React.FC<SquiggleItemProps> = ({
);
case "lambda":
return (
<ErrorBox heading="No Viewer">
There is no viewer currently available for function types.
</ErrorBox>
<FunctionChart
fn={expression.value}
chartSettings={chartSettings}
environment={environment}
/>
);
}
};
@ -191,15 +205,9 @@ export interface SquiggleChartProps {
/** If the output requires monte carlo sampling, the amount of samples */
sampleCount?: number;
/** The amount of points returned to draw the distribution */
outputXYPoints?: number;
kernelWidth?: number;
pointDistLength?: number;
/** If the result is a function, where the function starts */
diagramStart?: number;
/** If the result is a function, where the function ends */
diagramStop?: number;
/** If the result is a function, how many points along the function it samples */
diagramCount?: number;
environment?: environment;
/** If the result is a function, where the function starts, ends and the amount of stops */
chartSettings?: FunctionChartSettings;
/** When the environment changes */
onChange?(expr: squiggleExpression): void;
/** CSS width of the element */
@ -223,10 +231,10 @@ const ChartWrapper = styled.div`
"Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji";
`;
let defaultChartSettings = { start: 0, stop: 10, count: 100 };
export const SquiggleChart: React.FC<SquiggleChartProps> = ({
squiggleString = "",
sampleCount = 1000,
outputXYPoints = 1000,
environment,
onChange = () => {},
height = 60,
bindings = defaultBindings,
@ -235,17 +243,10 @@ export const SquiggleChart: React.FC<SquiggleChartProps> = ({
width,
showTypes = false,
showControls = false,
chartSettings = defaultChartSettings,
}: SquiggleChartProps) => {
let samplingInputs: samplingParams = {
sampleCount: sampleCount,
xyPointLength: outputXYPoints,
};
let expressionResult = run(
squiggleString,
bindings,
samplingInputs,
jsImports
);
let expressionResult = run(squiggleString, bindings, environment, jsImports);
let e = environment ? environment : defaultEnvironment;
let internal: JSX.Element;
if (expressionResult.tag === "Ok") {
let expression = expressionResult.value;
@ -258,6 +259,8 @@ export const SquiggleChart: React.FC<SquiggleChartProps> = ({
showSummary={showSummary}
showTypes={showTypes}
showControls={showControls}
chartSettings={chartSettings}
environment={e}
/>
);
} else {

View File

@ -5,7 +5,7 @@ import { CodeEditor } from "./CodeEditor";
import styled from "styled-components";
import type {
squiggleExpression,
samplingParams,
environment,
bindings,
jsImports,
} from "@quri/squiggle-lang";
@ -21,11 +21,7 @@ export interface SquiggleEditorProps {
/** The input string for squiggle */
initialSquiggleString?: string;
/** If the output requires monte carlo sampling, the amount of samples */
sampleCount?: number;
/** The amount of points returned to draw the distribution */
outputXYPoints?: number;
kernelWidth?: number;
pointDistLength?: number;
environment?: environment;
/** If the result is a function, where the function starts */
diagramStart?: number;
/** If the result is a function, where the function ends */
@ -57,13 +53,10 @@ const Input = styled.div`
export let SquiggleEditor: React.FC<SquiggleEditorProps> = ({
initialSquiggleString = "",
width,
sampleCount,
outputXYPoints,
kernelWidth,
pointDistLength,
diagramStart,
diagramStop,
diagramCount,
environment,
diagramStart = 0,
diagramStop = 10,
diagramCount = 100,
onChange,
bindings = defaultBindings,
jsImports = defaultImports,
@ -72,6 +65,11 @@ export let SquiggleEditor: React.FC<SquiggleEditorProps> = ({
showSummary = false,
}: SquiggleEditorProps) => {
let [expression, setExpression] = React.useState(initialSquiggleString);
let chartSettings = {
start: diagramStart,
stop: diagramStop,
count: diagramCount,
};
return (
<div>
<Input>
@ -85,14 +83,9 @@ export let SquiggleEditor: React.FC<SquiggleEditorProps> = ({
</Input>
<SquiggleChart
width={width}
environment={environment}
squiggleString={expression}
sampleCount={sampleCount}
outputXYPoints={outputXYPoints}
kernelWidth={kernelWidth}
pointDistLength={pointDistLength}
diagramStart={diagramStart}
diagramStop={diagramStop}
diagramCount={diagramCount}
chartSettings={chartSettings}
onChange={onChange}
bindings={bindings}
jsImports={jsImports}
@ -140,11 +133,7 @@ export interface SquigglePartialProps {
/** The input string for squiggle */
initialSquiggleString?: string;
/** If the output requires monte carlo sampling, the amount of samples */
sampleCount?: number;
/** The amount of points returned to draw the distribution */
outputXYPoints?: number;
kernelWidth?: number;
pointDistLength?: number;
environment?: environment;
/** If the result is a function, where the function starts */
diagramStart?: number;
/** If the result is a function, where the function ends */
@ -165,14 +154,9 @@ export let SquigglePartial: React.FC<SquigglePartialProps> = ({
initialSquiggleString = "",
onChange,
bindings = defaultBindings,
sampleCount = 1000,
outputXYPoints = 1000,
environment,
jsImports = defaultImports,
}: SquigglePartialProps) => {
let samplingInputs: samplingParams = {
sampleCount: sampleCount,
xyPointLength: outputXYPoints,
};
let [expression, setExpression] = React.useState(initialSquiggleString);
let [error, setError] = React.useState<string | null>(null);
@ -180,7 +164,7 @@ export let SquigglePartial: React.FC<SquigglePartialProps> = ({
let squiggleResult = runPartial(
expression,
bindings,
samplingInputs,
environment,
jsImports
);
if (squiggleResult.tag == "Ok") {

View File

@ -4,6 +4,11 @@ import ReactDOM from "react-dom";
import { SquiggleChart } from "./SquiggleChart";
import CodeEditor from "./CodeEditor";
import styled from "styled-components";
import {
defaultBindings,
environment,
defaultImports,
} from "@quri/squiggle-lang";
interface FieldFloatProps {
label: string;
@ -96,6 +101,15 @@ let SquigglePlayground: FC<PlaygroundProps> = ({
let [diagramStart, setDiagramStart] = useState(0);
let [diagramStop, setDiagramStop] = useState(10);
let [diagramCount, setDiagramCount] = useState(20);
let chartSettings = {
start: diagramStart,
stop: diagramStop,
count: diagramCount,
};
let env: environment = {
sampleCount: sampleCount,
xyPointLength: outputXYPoints,
};
return (
<ShowBox height={height}>
<Row>
@ -112,15 +126,13 @@ let SquigglePlayground: FC<PlaygroundProps> = ({
<Display maxHeight={height - 3}>
<SquiggleChart
squiggleString={squiggleString}
sampleCount={sampleCount}
outputXYPoints={outputXYPoints}
diagramStart={diagramStart}
diagramStop={diagramStop}
diagramCount={diagramCount}
pointDistLength={pointDistLength}
environment={env}
chartSettings={chartSettings}
height={150}
showTypes={showTypes}
showControls={showControls}
bindings={defaultBindings}
jsImports={defaultImports}
showSummary={showSummary}
/>
</Display>

View File

@ -1,6 +1,5 @@
import * as _ from "lodash";
import {
samplingParams,
environment,
defaultEnvironment,
evaluatePartialUsingExternalBindings,
@ -8,6 +7,7 @@ import {
externalBindings,
expressionValue,
errorValue,
foreignFunctionInterface,
} from "../rescript/TypescriptInterface.gen";
export {
makeSampleSetDist,
@ -15,25 +15,31 @@ export {
distributionErrorToString,
distributionError,
} from "../rescript/TypescriptInterface.gen";
export type {
samplingParams,
errorValue,
externalBindings as bindings,
jsImports,
};
export type { errorValue, externalBindings as bindings, jsImports };
import {
jsValueToBinding,
jsValueToExpressionValue,
jsValue,
rescriptExport,
squiggleExpression,
convertRawToTypescript,
lambdaValue,
} from "./rescript_interop";
import { result, resultMap, tag, tagged } from "./types";
import { Distribution, shape } from "./distribution";
export { Distribution, squiggleExpression, result, resultMap, shape };
export {
Distribution,
squiggleExpression,
result,
resultMap,
shape,
lambdaValue,
environment,
defaultEnvironment,
};
export let defaultSamplingInputs: samplingParams = {
export let defaultSamplingInputs: environment = {
sampleCount: 10000,
xyPointLength: 10000,
};
@ -72,6 +78,20 @@ export function runPartial(
);
}
export function runForeign(
fn: lambdaValue,
args: jsValue[],
environment?: environment
): result<squiggleExpression, errorValue> {
let e = environment ? environment : defaultEnvironment;
let res: result<expressionValue, errorValue> = foreignFunctionInterface(
fn,
args.map(jsValueToExpressionValue),
e
);
return resultMap(res, (x) => createTsExport(x, e));
}
function mergeImportsWithBindings(
bindings: externalBindings,
imports: jsImports

View File

@ -1,5 +1,6 @@
import * as _ from "lodash";
import {
expressionValue,
mixedShape,
sampleSetDist,
genericDist,
@ -87,6 +88,8 @@ export type squiggleExpression =
| tagged<"number", number>
| tagged<"record", { [key: string]: squiggleExpression }>;
export { lambdaValue };
export function convertRawToTypescript(
result: rescriptExport,
environment: environment
@ -168,3 +171,21 @@ export function jsValueToBinding(value: jsValue): rescriptExport {
return { TAG: 7, _0: _.mapValues(value, jsValueToBinding) };
}
}
export function jsValueToExpressionValue(value: jsValue): expressionValue {
if (typeof value === "boolean") {
return { tag: "EvBool", value: value as boolean };
} else if (typeof value === "string") {
return { tag: "EvString", value: value as string };
} else if (typeof value === "number") {
return { tag: "EvNumber", value: value as number };
} else if (Array.isArray(value)) {
return { tag: "EvArray", value: value.map(jsValueToExpressionValue) };
} else {
// Record
return {
tag: "EvRecord",
value: _.mapValues(value, jsValueToExpressionValue),
};
}
}

View File

@ -84,3 +84,6 @@ type environment = ReducerInterface_ExpressionValue.environment
@genType
let defaultEnvironment = ReducerInterface_ExpressionValue.defaultEnvironment
@genType
let foreignFunctionInterface = Reducer.foreignFunctionInterface