First commit: better late than never
This commit is contained in:
commit
a0a8093e12
38
calibration.py
Normal file
38
calibration.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# This function is a sklearn.calibration.calibration_curve modification
|
||||||
|
def calibration_curve(y_true, y_prob, *, n_bins=5, strategy="uniform"):
    """Compute a calibration curve plus the size of each occupied bin.

    A modified sklearn.calibration.calibration_curve that additionally
    returns how many samples fell into each non-empty bin.

    Parameters
    ----------
    y_true : array-like of 0/1 outcomes.
    y_prob : array-like of predicted probabilities, same length as y_true.
    n_bins : number of bins used to discretise [0, 1].
    strategy : "uniform" (equal-width bins) or "quantile" (equal-count bins).

    Returns
    -------
    (prob_true, prob_pred, counts), each restricted to non-empty bins.

    Raises
    ------
    ValueError
        If strategy is neither "uniform" nor "quantile".
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    if strategy == "quantile":  # Determine bin edges by distribution of data
        percent_points = np.linspace(0, 1, n_bins + 1) * 100
        edges = np.percentile(y_prob, percent_points)
        # Nudge the top edge up so the maximum prediction lands inside a bin.
        edges[-1] += 1e-8
    elif strategy == "uniform":
        edges = np.linspace(0.0, 1.0 + 1e-8, n_bins + 1)
    else:
        raise ValueError(
            "Invalid entry to 'strategy' input. Strategy "
            "must be either 'quantile' or 'uniform'."
        )

    bin_ids = np.digitize(y_prob, edges) - 1

    pred_sums = np.bincount(bin_ids, weights=y_prob, minlength=len(edges))
    true_sums = np.bincount(bin_ids, weights=y_true, minlength=len(edges))
    totals = np.bincount(bin_ids, minlength=len(edges))

    occupied = totals != 0
    prob_true = true_sums[occupied] / totals[occupied]
    prob_pred = pred_sums[occupied] / totals[occupied]

    return prob_true, prob_pred, totals[occupied]
|
||||||
|
|
||||||
|
|
||||||
|
def overconfidence(y_true, y_pred):
    """Over/under-confidence score for binary forecasts.

    ``assigned`` is the probability that was given to the outcome that
    actually happened.  The score is the ratio of empirical moments
    ``E[(x - 1)(x - 0.5)] / E[(x - 0.5)^2]`` where x = assigned.
    """
    assigned = y_pred * y_true + (1 - y_pred) * (1 - y_true)
    centred = assigned - 0.5
    return np.mean((assigned - 1) * centred) / np.mean(centred * centred)
|
43
firebase_requests.py
Normal file
43
firebase_requests.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
import json

import streamlit as st  # BUG FIX: st.secrets was used below without this import
from google.cloud import firestore
from google.oauth2 import service_account  # BUG FIX: was used below without this import

from gjo_requests import request_forecasts, request_resolutions

# The Firestore service-account credentials are stored as a JSON string in
# Streamlit's secrets; parse them and open one module-level client.
firestore_info = json.loads(st.secrets["firestore_info"])
credentials = service_account.Credentials.from_service_account_info(firestore_info)
db = firestore.Client(credentials=credentials, project="gjo-calibration")
|
||||||
|
|
||||||
|
|
||||||
|
def get_forecasts(uid, questions, platform_url, headers, cookies):
    """Return {question_id: forecasts} for `questions`, using Firestore as a cache.

    Forecasts already stored under users/<uid> are reused; only the missing
    questions are scraped from the platform, and those are written back to
    Firestore for next time.
    """
    db_forecasts = db.collection("users").document(uid).get().to_dict()
    db_forecasts = dict() if db_forecasts is None else db_forecasts

    # BUG FIX / consistency with get_resolutions: only return the questions
    # that were asked for.  The user document may hold forecasts on many other
    # questions, which previously leaked into the result (and inflated the
    # forecast counts computed by callers).
    wanted = set(questions)
    relevant_forecasts = {
        key: value for key, value in db_forecasts.items() if key in wanted
    }

    missing_forecasts_qs = list(wanted - set(relevant_forecasts))
    missing_forecasts = request_forecasts(
        uid, missing_forecasts_qs, platform_url, headers, cookies
    )

    if missing_forecasts:
        if not db_forecasts:
            # The document must exist before update() can be called on it.
            db.collection("users").add({}, uid)
        db.collection("users").document(uid).update(missing_forecasts)

    return {**relevant_forecasts, **missing_forecasts}
|
||||||
|
|
||||||
|
|
||||||
|
def get_resolutions(questions, platform_url, headers, cookies):
    """Return {question_id: resolution} for `questions`, using Firestore as a cache.

    Cached resolutions are read from the questions/resolutions document; only
    the missing ones are scraped from the platform and written back.
    """
    stored = db.collection("questions").document("resolutions").get().to_dict()
    stored = dict() if stored is None else stored

    wanted = set(questions)
    relevant_resolutions = {qid: res for qid, res in stored.items() if qid in wanted}

    to_fetch = list(wanted - set(relevant_resolutions))
    fetched = request_resolutions(to_fetch, platform_url, headers, cookies)

    if fetched:
        db.collection("questions").document("resolutions").update(fetched)

    return {**relevant_resolutions, **fetched}
|
175
gjo_requests.py
Normal file
175
gjo_requests.py
Normal file
|
@ -0,0 +1,175 @@
|
||||||
|
import asyncio
import logging
import re
from itertools import count

import aiohttp
import aioitertools
import requests
import streamlit as st
from bs4 import BeautifulSoup

# Create and register a fresh event loop at import time.
# NOTE(review): the request_* wrappers below use asyncio.run(), which manages
# its own loop, so this looks redundant — presumably it works around Streamlit
# script threads having no default loop; confirm before removing.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache
def get_resolved_questions(uid, platform_url, headers, cookies):
    """Return the ids of all resolved questions the user forecasted on.

    Walks the user's paginated scores pages until a page yields no question
    links.  Memoised by Streamlit's cache.
    """
    logging.info(
        f"[ ] get_resolved_questions for uid={uid}, platform_url={platform_url}"
    )

    questions = []  # [question_id]

    for page_num in count(1):
        url = f"{platform_url}/memberships/{uid}/scores/?page={page_num}"
        page = requests.get(url, headers=headers, cookies=cookies).text

        # BUG FIX: raw string — "\d" in a plain literal is an invalid escape
        # sequence (SyntaxWarning on modern Python).
        extracted_qs = re.findall(r"/questions/(\d+)", page)
        questions.extend(extracted_qs)

        # An empty page means we have walked past the last one.
        if not extracted_qs:
            break

    logging.info(
        f"[X] get_resolved_questions for uid={uid}, platform_url={platform_url}"
    )

    return questions
|
||||||
|
|
||||||
|
|
||||||
|
async def get_question_resolution(qid, platform_url, session):
    """Fetch one question's page and extract its resolution.

    Returns {"y_true": tuple} — an indicator tuple over the question's answer
    options (1 for the option(s) that resolved true).

    BUG FIX: all log messages in this function referenced an undefined name
    `uid` (the parameter is `qid`); the f-strings raised NameError on every
    call.
    """
    logging.info(
        f"[ ] get_question_resolution for qid={qid}, platform_url={platform_url}"
    )

    url = f"{platform_url}/questions/{qid}"

    async with session.get(url) as resp:
        if resp.status != 200:
            logging.error(
                f"get_question_resolution for qid={qid}, platform_url={platform_url} | "
                f"resp.status == {resp.status} → {resp.reason}"
            )

        page = await resp.text()

    soup = BeautifulSoup(page, "html.parser")
    soup = soup.find_all("div", {"id": "prediction-interface-container"})[0]

    binary = soup.find_all("div", {"class": "binary-probability-value"})
    if binary:
        # Binary question: resolution read from the second probability cell.
        # NOTE(review): assumes binary[1] holds the resolved answer's text —
        # confirm against the live page markup.
        y_true = (0, 1) if re.search("Yes", binary[1].text) is None else (1, 0)
    else:
        # Multi-option question: count <i> icons per table row (skipping the
        # header row) to mark the resolved option(s).
        tables = soup.find_all("table")
        y_true = tuple(len(tr.findAll("i")) for tr in tables[0].findAll("tr")[1:])

    logging.info(
        f"[X] get_question_resolution for qid={qid}, platform_url={platform_url}"
    )
    return {"y_true": y_true}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_forecasts_from_page(page):
    """Parse one prediction_sets HTML page into a list of forecasts.

    Returns [{"y_pred": (prob, ...), "timestamp": str}, ...] in page order.
    A forecast reported as a single percentage is expanded to (p, 1 - p).
    """
    soup = BeautifulSoup(page, "html.parser")
    soup_predictions = soup.find_all("div", {"class": "prediction-values"})
    # Raw string (BUG FIX): "\s"/"\d" are invalid escape sequences in a plain
    # literal.
    predictions = [re.findall(r"\n\s*(\d+)%", p_tag.text) for p_tag in soup_predictions]
    predictions = [tuple(int(prob) / 100 for prob in pred) for pred in predictions]
    predictions = [
        (pred[0], 1 - pred[0]) if len(pred) == 1 else pred for pred in predictions
    ]

    # Timestamps: find a line containing "made a forecast", then grab the
    # timestamp from the next line containing a localizable-timestamp span.
    timestamps = []
    looking_for_a_forecast = True
    for line in page.split("\n"):
        if looking_for_a_forecast:
            if re.findall("made a forecast", line):
                looking_for_a_forecast = False
        else:
            hit = re.findall('<span data-localizable-timestamp="([^"]+)">', line)
            if hit:
                timestamps.extend(hit)
                looking_for_a_forecast = True

    if len(timestamps) != len(predictions):
        # BUG FIX: the original message referenced undefined names
        # (uid/qid/page_num), raising NameError instead of logging, and had
        # the two counts swapped.
        logging.error(
            f"In _extract_forecasts_from_page "
            f"got different number of predictions ({len(predictions)}) "
            f"and timestamps ({len(timestamps)})."
        )

    # zip() silently truncates to the shorter list if the counts disagree.
    return [
        {"y_pred": pred, "timestamp": timestamp}
        for pred, timestamp in zip(predictions, timestamps)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_forecasts_on_the_question(uid, qid, platform_url, session):
    """Collect every forecast the user made on one question, page by page."""
    logging.info(
        f"[ ] get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url}"
    )

    forecasts = []  # [{"y_pred": (probs, ...), "timestamp": timestamp}, ...]

    page_num = 0
    while True:
        page_num += 1
        url = f"{platform_url}/questions/{qid}/prediction_sets?membership_id={uid}&page={page_num}"

        async with session.get(url) as resp:
            if resp.status != 200:
                logging.error(
                    f"get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url} | "
                    f"resp.status == {resp.status} → {resp.reason}"
                )

            page = await resp.text()

        extracted_forecasts = _extract_forecasts_from_page(page)
        forecasts.extend(extracted_forecasts)

        # The page past the last one contains no forecasts — stop there.
        if not extracted_forecasts:
            break

    logging.info(
        f"[X] get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url}"
    )
    return forecasts
|
||||||
|
|
||||||
|
|
||||||
|
# ---
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_forecasts(uid, questions, platform_url, headers, cookies):
    """Fetch forecasts for every question concurrently (at most 5 in flight)."""
    async with aiohttp.ClientSession(headers=headers, cookies=cookies) as session:
        tasks = [
            get_forecasts_on_the_question(uid, qid, platform_url, session)
            for qid in questions
        ]
        results = await aioitertools.asyncio.gather(*tasks, limit=5)
    # gather preserves input order, so pair results back up with questions.
    return dict(zip(questions, results))
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_resolutions(questions, platform_url, headers, cookies):
    """Fetch resolutions for every question concurrently (at most 5 in flight)."""
    async with aiohttp.ClientSession(headers=headers, cookies=cookies) as session:
        tasks = [
            get_question_resolution(qid, platform_url, session) for qid in questions
        ]
        results = await aioitertools.asyncio.gather(*tasks, limit=5)
    # gather preserves input order, so pair results back up with questions.
    return dict(zip(questions, results))
|
||||||
|
|
||||||
|
|
||||||
|
def request_forecasts(uid, missing_forecasts_qs, platform_url, headers, cookies):
    """Synchronous wrapper around async_get_forecasts."""
    coro = async_get_forecasts(uid, missing_forecasts_qs, platform_url, headers, cookies)
    return asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
|
def request_resolutions(missing_resolutions_qs, platform_url, headers, cookies):
    """Synchronous wrapper around async_get_resolutions."""
    coro = async_get_resolutions(missing_resolutions_qs, platform_url, headers, cookies)
    return asyncio.run(coro)
|
172
plotting.py
Normal file
172
plotting.py
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
import numpy as np
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
from calibration import calibration_curve
|
||||||
|
|
||||||
|
|
||||||
|
def plotly_calibration(y_true, y_pred, n_bins, strategy="quantile"):
    """Plotly calibration plot: per-bin fraction of positives vs mean prediction.

    Markers come from calibration_curve; error bars are the binomial standard
    deviation sqrt(p * (1 - p) / N) of each bin's positive fraction.
    """
    fraction_of_positives, mean_predicted_value, counts = calibration_curve(
        y_true, y_pred, n_bins=n_bins, strategy=strategy
    )
    # Binomial std of the positive fraction in each bin.
    error_y = np.sqrt((fraction_of_positives) * (1 - fraction_of_positives) / counts)

    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            x=mean_predicted_value,
            y=fraction_of_positives,
            customdata=counts,  # bin sizes, surfaced as "N" in the hover text
            mode="markers",
            error_y=dict(
                type="data",
                array=error_y,
                thickness=1.5,
                width=3,
            ),
            hovertemplate="<br>".join(
                [
                    "x: %{x:.3f}",
                    "y: %{y:.3f}",
                    "N: %{customdata}",
                    "<extra></extra>",
                ]
            ),
            showlegend=False,
        )
    )

    # Dotted y = x diagonal: the line of perfect calibration.
    fig.add_shape(
        type="line",
        x0=0,
        y0=0,
        x1=1,
        y1=1,
        line=dict(
            color="LightSeaGreen",
            width=2,
            dash="dot",
        ),
        opacity=0.5,
    )

    fig.update_layout(
        width=800,
        height=800,
        title="Calibration plot",
        xaxis_title="Mean predicted value",
        yaxis_title="Fraction of positives (± std)",
    )

    # Slightly padded [0, 1] axes, locked to a square aspect ratio.
    fig.update_xaxes(
        range=[-0.05, 1.05],
        constrain="domain",
    )

    fig.update_yaxes(
        range=[-0.05, 1.05],
        constrain="domain",
        scaleanchor="x",
        scaleratio=1,
    )

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plotly_calibration_odds(y_true, y_pred, n_bins, strategy="quantile"):
    """Calibration plot on a log2-odds scale instead of the probability scale.

    Both axes show log2(p / (1 - p)); tick and hover labels render this as
    "a : b" odds strings.  Inputs are clipped away from 0 and 1 so the
    transform stays finite.
    """
    y_pred = np.clip(y_pred, 0.005, 0.995)  # clipping to avoid undefined odds
    # Clip outcomes too: a bin whose fraction of positives is exactly 0 or 1
    # would otherwise transform to ±infinity below.
    y_true = np.clip(y_true, 1e-3, 1 - 1e-3)
    fraction_of_positives, mean_predicted_value, counts = calibration_curve(
        y_true, y_pred, n_bins=n_bins, strategy=strategy
    )
    # Binomial std of the positive fraction in each bin (probability scale).
    error_y = np.sqrt((fraction_of_positives) * (1 - fraction_of_positives) / counts)

    fig = go.Figure()

    # Probability → log2-odds: 1 / (1 - x) - 1 == x / (1 - x).
    transform = lambda x: np.log2(1 / (1 - x) - 1)  # 66.6% → 2^{1}:1 → 1

    # Hover data columns: [0] bin count, [1] predicted odds as "a : b" text,
    # [2] actual odds as "a : b" text.
    customdata = np.dstack(
        [
            counts,
            [
                f"{2**x:.1f} : 1" if x > 0 else f"1 : {2**-x:.1f}"
                for x in transform(mean_predicted_value)
            ],
            [
                f"{2**x:.1f} : 1" if x > 0 else f"1 : {2**-x:.1f}"
                for x in transform(fraction_of_positives)
            ],
        ]
    ).squeeze()

    fig.add_trace(
        go.Scatter(
            x=transform(mean_predicted_value),
            y=transform(fraction_of_positives),
            customdata=customdata,
            mode="markers",
            # Error bars become asymmetric after the nonlinear odds transform,
            # so the up/down arms are computed separately.
            error_y=dict(
                type="data",
                symmetric=False,
                array=transform(fraction_of_positives + error_y)
                - transform(fraction_of_positives),
                arrayminus=transform(fraction_of_positives)
                - transform(fraction_of_positives - error_y),
                thickness=1.5,
                width=3,
            ),
            hovertemplate="<br>".join(
                [
                    "x: %{customdata[1]}",
                    "y: %{customdata[2]}",
                    "N: %{customdata[0]}",
                    "<extra></extra>",
                ]
            ),
            showlegend=False,
        )
    )

    # Dotted y = x diagonal: perfect calibration, now in odds space.
    fig.add_shape(
        type="line",
        x0=-8,
        y0=-8,
        x1=8,
        y1=8,
        line=dict(
            color="LightSeaGreen",
            width=2,
            dash="dot",
        ),
        opacity=0.5,
    )

    fig.update_layout(
        width=800,
        height=800,
        title="Calibration plot in terms of odds",
        xaxis_title="Mean predicted value",
        yaxis_title="Fraction of positives (± std)",
    )

    # Integer log2-odds ticks labelled as "a : b" odds strings.
    fig.update_xaxes(
        range=[-8, 8],
        constrain="domain",
        tickmode="array",
        tickvals=list(range(-10, 10)),
        ticktext=[
            f"{2**x} : 1" if x > 0 else f"1 : {2**-x}" for x in list(range(-10, 10))
        ],
    )

    fig.update_yaxes(
        range=[-8, 8],
        constrain="domain",
        scaleanchor="x",
        scaleratio=1,
        tickvals=list(range(-10, 10)),
        ticktext=[
            f"{2**x} : 1" if x > 0 else f"1 : {2**-x}" for x in list(range(-10, 10))
        ],
    )

    return fig
|
24
requirements.txt
Normal file
24
requirements.txt
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
lxml==4.6.1
|
||||||
|
lockfile==0.12.2
|
||||||
|
numpy==1.19.3
|
||||||
|
keyring==21.5.0
|
||||||
|
mypy_extensions==0.4.3
|
||||||
|
pandas==1.1.3
|
||||||
|
typing_extensions==3.7.4.3
|
||||||
|
aiohttp==3.7.4.post0
|
||||||
|
aioitertools==0.7.1
|
||||||
|
beautifulsoup4==4.9.3
|
||||||
|
brotli==1.0.9
|
||||||
|
cryptography==3.4.7
|
||||||
|
Cython==0.29.23
|
||||||
|
docutils==0.17.1
|
||||||
|
importlib_metadata==4.4.0
|
||||||
|
ipaddr==2.2.0
|
||||||
|
ordereddict==1.1
|
||||||
|
plotly==4.14.3
|
||||||
|
protobuf==3.17.1
|
||||||
|
pyOpenSSL==20.0.1
|
||||||
|
streamlit==0.82.0
|
||||||
|
uncurl==0.0.11
|
||||||
|
wincertstore==0.2
|
||||||
|
zipp==3.4.1
|
100
strmlt.py
Normal file
100
strmlt.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import streamlit as st
|
||||||
|
import uncurl
|
||||||
|
from calibration import overconfidence
|
||||||
|
from firebase_requests import get_forecasts, get_resolutions
|
||||||
|
from gjo_requests import get_resolved_questions
|
||||||
|
from plotting import plotly_calibration, plotly_calibration_odds
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Streamlit app: scrape a user's GJO/Foretell forecasts and resolutions,
    # then show calibration plots and an over/under-confidence score.
    st.title("Learn how calibrated are you?")

    # ---

    # if st.checkbox('I am new! Show me instructions.'):
    #     st.write("""
    #     Hey!
    #     """)

    # ---

    platform = st.selectbox(
        "Which platform are you using?",
        ["Good Judgement Open", "CSET Foretell"],
    )
    platform_url = {
        "Good Judgement Open": "https://www.gjopen.com",
        "CSET Foretell": "https://www.cset-foretell.com",
    }[platform]

    uid = st.number_input("What is your user ID?", min_value=1, value=28899)
    uid = str(uid)

    curl_value = ""

    curl_command = st.text_area(
        "Ugh... Gimme your cURL info...", value=curl_value.strip()
    )
    # BUG FIX: uncurl.parse_context crashes on an empty string, which is the
    # text area's initial state — wait until the user pastes something.
    if not curl_command.strip():
        st.stop()
    curl_content = uncurl.parse_context(curl_command)
    headers, cookies = curl_content.headers, curl_content.cookies

    # ---

    questions = get_resolved_questions(uid, platform_url, headers, cookies)

    st.write(f"{len(questions)} questions you forecasted on have resolved.")

    # ---
    # TODO: Make a progress bar..?

    forecasts = get_forecasts(uid, questions, platform_url, headers, cookies)
    resolutions = get_resolutions(questions, platform_url, headers, cookies)

    # ---

    num_forecasts = sum(len(f) for f in forecasts.values())
    st.write(
        f"On these {len(questions)} questions you've made {num_forecasts} forecasts."
    )

    flatten = lambda t: [item for sublist in t for item in sublist]
    # Repeat each question's resolution tuple once per forecast on it, so the
    # flattened y_true lines up element-wise with the flattened y_pred.
    y_true = flatten(resolutions[q]["y_true"] for q in questions for _ in forecasts[q])
    y_pred = flatten(f["y_pred"] for q in questions for f in forecasts[q])

    # Note that I am "double counting" each prediction.
    if st.checkbox("Drop last"):
        # Drop each forecast's last option — presumably the implied complement
        # of a binary forecast; confirm against the data format.
        y_true = flatten(
            resolutions[q]["y_true"][:-1] for q in questions for _ in forecasts[q]
        )
        y_pred = flatten(f["y_pred"][:-1] for q in questions for f in forecasts[q])

    y_true, y_pred = np.array(y_true), np.array(y_pred)

    st.write(f"Which gives us {len(y_pred)} datapoints to work with.")

    # ---

    strategy = st.selectbox(
        # BUG FIX: typo "stranegy" in the user-facing label.
        "Which binning strategy do you prefer?",
        ["uniform", "quantile"],
    )

    # sqrt(N) equal-count bins, or 21 equal-width bins, as a default.
    recommended_n_bins = int(np.sqrt(len(y_pred))) if strategy == "quantile" else 20 + 1
    n_bins = st.number_input(
        "How many bins do you want me to display?",
        min_value=1,
        value=recommended_n_bins,
    )

    fig = plotly_calibration(y_true, y_pred, n_bins=n_bins, strategy=strategy)
    st.plotly_chart(fig, use_container_width=True)

    overconf = overconfidence(y_true, y_pred)
    st.write(f"Your over/under- confidence score is {overconf:.2f}.")

    # ---

    fig = plotly_calibration_odds(y_true, y_pred, n_bins=n_bins, strategy=strategy)
    st.plotly_chart(fig, use_container_width=True)
|
Loading…
Reference in New Issue
Block a user