First commit: better late than never
This commit is contained in:
commit
a0a8093e12
38
calibration.py
Normal file
38
calibration.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
# This function is a sklearn.calibration.calibration_curve modification
def calibration_curve(y_true, y_prob, *, n_bins=5, strategy="uniform"):
    """Bin predictions and return (prob_true, prob_pred, per-bin counts).

    Variant of sklearn.calibration.calibration_curve that additionally
    returns the number of samples that fell into every non-empty bin.
    """
    labels = np.array(y_true)
    probs = np.array(y_prob)

    if strategy == "quantile":  # bin edges follow the data distribution
        edges = np.percentile(probs, np.linspace(0, 1, n_bins + 1) * 100)
        edges[-1] = edges[-1] + 1e-8  # make the last bin right-inclusive
    elif strategy == "uniform":
        edges = np.linspace(0.0, 1.0 + 1e-8, n_bins + 1)
    else:
        raise ValueError(
            "Invalid entry to 'strategy' input. Strategy "
            "must be either 'quantile' or 'uniform'."
        )

    # Index of the bin each probability lands in.
    ids = np.digitize(probs, edges) - 1

    pred_sums = np.bincount(ids, weights=probs, minlength=len(edges))
    true_sums = np.bincount(ids, weights=labels, minlength=len(edges))
    totals = np.bincount(ids, minlength=len(edges))

    occupied = totals != 0
    prob_true = true_sums[occupied] / totals[occupied]
    prob_pred = pred_sums[occupied] / totals[occupied]

    return prob_true, prob_pred, totals[occupied]
|
||||
|
||||
|
||||
def overconfidence(y_true, y_pred):
    """Return an over/under-confidence score for binary forecasts.

    `x` below is the probability assigned to the realized outcome; the score
    is a ratio of two moments of `x` centred at 0.5.
    """
    x = y_pred * y_true + (1 - y_pred) * (1 - y_true)
    centered = x - 0.5
    return np.mean((x - 1) * centered) / np.mean(centered * centered)
|
43
firebase_requests.py
Normal file
43
firebase_requests.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
import json

import streamlit as st  # bug fix: st.secrets was used below without this import
from gjo_requests import request_forecasts, request_resolutions
from google.cloud import firestore
from google.oauth2 import service_account  # bug fix: was used below but never imported

# Firestore client shared by the helpers in this module; credentials come
# from the streamlit secrets store.
firestore_info = json.loads(st.secrets["firestore_info"])
credentials = service_account.Credentials.from_service_account_info(firestore_info)
db = firestore.Client(credentials=credentials, project="gjo-calibration")
|
||||
|
||||
|
||||
def get_forecasts(uid, questions, platform_url, headers, cookies):
    """Forecasts for `questions`, served from Firestore with a fill-in fetch.

    Anything not already cached under the user's document is requested from
    the platform and written back to Firestore before being returned.
    """
    cached = db.collection("users").document(uid).get().to_dict()
    if cached is None:
        cached = {}

    to_fetch = list(set(questions) - set(cached))
    fetched = request_forecasts(uid, to_fetch, platform_url, headers, cookies)

    if fetched:
        if not cached:
            # The user's document does not exist yet — create an empty one.
            db.collection("users").add({}, uid)
        db.collection("users").document(uid).update(fetched)

    return {**cached, **fetched}
|
||||
|
||||
|
||||
def get_resolutions(questions, platform_url, headers, cookies):
    """Resolutions for `questions`, served from Firestore with a fill-in fetch."""
    stored = db.collection("questions").document("resolutions").get().to_dict()
    if stored is None:
        stored = {}
    wanted = set(questions)
    cached = {qid: res for qid, res in stored.items() if qid in wanted}

    to_fetch = list(wanted - set(cached))
    fetched = request_resolutions(to_fetch, platform_url, headers, cookies)

    if fetched:
        # Persist newly scraped resolutions for future sessions.
        db.collection("questions").document("resolutions").update(fetched)

    return {**cached, **fetched}
|
175
gjo_requests.py
Normal file
175
gjo_requests.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from itertools import count
|
||||
|
||||
import aiohttp
|
||||
import aioitertools
|
||||
import requests
|
||||
import streamlit as st
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
# NOTE(review): a fresh event loop is installed at import time, but the
# request_* wrappers below use asyncio.run(), which creates its own loop —
# this module-level loop looks vestigial; confirm before removing.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
|
||||
|
||||
|
||||
@st.cache
def get_resolved_questions(uid, platform_url, headers, cookies):
    """Return ids of resolved questions the user `uid` forecasted on.

    Walks the paginated scores listing until a page yields no question
    links. Results are cached by streamlit across reruns.
    """
    logging.info(
        f"[ ] get_resolved_questions for uid={uid}, platform_url={platform_url}"
    )

    questions = []  # [question_id]

    for page_num in count(1):
        url = f"{platform_url}/memberships/{uid}/scores/?page={page_num}"
        page = requests.get(url, headers=headers, cookies=cookies).text

        # Raw string fixes the invalid "\d" escape in the original pattern.
        extracted_qs = re.findall(r"/questions/(\d+)", page)
        questions.extend(extracted_qs)

        if not extracted_qs:
            break

    logging.info(
        f"[X] get_resolved_questions for uid={uid}, platform_url={platform_url}"
    )

    return questions
|
||||
|
||||
|
||||
async def get_question_resolution(qid, platform_url, session):
    """Scrape the resolution of question `qid`.

    Returns {"y_true": tuple}: for binary questions (1, 0) when "Yes" is
    found in the resolution markup, else (0, 1); for multi-option questions
    a per-row count of <i> icons in the first results table.
    """
    # Bug fix: the log lines in this coroutine referenced an undefined `uid`
    # (NameError at runtime) — only the question id is known here.
    logging.info(
        f"[ ] get_question_resolution for qid={qid}, platform_url={platform_url}"
    )

    url = f"{platform_url}/questions/{qid}"

    async with session.get(url) as resp:
        if resp.status != 200:
            logging.error(
                f"get_question_resolution for qid={qid}, platform_url={platform_url} | "
                f"resp.status == {resp.status} → {resp.reason}"
            )

        page = await resp.text()

    soup = BeautifulSoup(page, "html.parser")
    soup = soup.find_all("div", {"id": "prediction-interface-container"})[0]

    binary = soup.find_all("div", {"class": "binary-probability-value"})
    if binary:
        y_true = (0, 1) if re.search("Yes", binary[1].text) is None else (1, 0)
    else:
        tables = soup.find_all("table")
        y_true = tuple(len(tr.findAll("i")) for tr in tables[0].findAll("tr")[1:])

    logging.info(
        f"[X] get_question_resolution for qid={qid}, platform_url={platform_url}"
    )
    return {"y_true": y_true}
|
||||
|
||||
|
||||
def _extract_forecasts_from_page(page):
    """Extract [{"y_pred": (probs, ...), "timestamp": ts}, ...] from one HTML page."""
    soup = BeautifulSoup(page, "html.parser")
    soup_predictions = soup.find_all("div", {"class": "prediction-values"})
    # Raw string fixes the invalid escape sequences in the original pattern.
    predictions = [re.findall(r"\n\s*(\d+)%", p_tag.text) for p_tag in soup_predictions]
    predictions = [tuple(int(prob) / 100 for prob in pred) for pred in predictions]
    # Binary questions expose a single probability; expand it to (p, 1 - p).
    predictions = [
        (pred[0], 1 - pred[0]) if len(pred) == 1 else pred for pred in predictions
    ]

    # I search for a line containing "made a forecast",
    # then for the next line containing <span data-localizable-timestamp="[^"]*">
    # and grab a timestamp from it.
    timestamps = []
    looking_for_a_forecast = True
    for line in page.split("\n"):
        if looking_for_a_forecast:
            hit = re.findall("made a forecast", line)
            if hit:
                looking_for_a_forecast = False
        else:
            hit = re.findall(r'<span data-localizable-timestamp="([^"]+)">', line)
            if hit:
                timestamps.extend(hit)
                looking_for_a_forecast = True

    if len(timestamps) != len(predictions):
        # Bug fix: this message referenced undefined names (uid, qid, page_num),
        # which raised NameError whenever the branch was hit, and it also had
        # the two counts swapped.
        logging.error(
            f"In _extract_forecasts_from_page "
            f"got different number of predictions ({len(predictions)}) "
            f"and timestamps ({len(timestamps)})."
        )

    return [
        {"y_pred": pred, "timestamp": timestamp}
        for pred, timestamp in zip(predictions, timestamps)
    ]
|
||||
|
||||
|
||||
async def get_forecasts_on_the_question(uid, qid, platform_url, session):
    """All forecasts by user `uid` on question `qid`.

    Pages through the prediction_sets listing until a page yields nothing.
    Returns [{"y_pred": (probs, ...), "timestamp": timestamp}, ...].
    """
    logging.info(
        f"[ ] get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url}"
    )

    collected = []

    for page_num in count(1):
        url = (
            f"{platform_url}/questions/{qid}/prediction_sets"
            f"?membership_id={uid}&page={page_num}"
        )

        async with session.get(url) as resp:
            if resp.status != 200:
                logging.error(
                    f"get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url} | "
                    f"resp.status == {resp.status} → {resp.reason}"
                )

            page = await resp.text()

        page_forecasts = _extract_forecasts_from_page(page)
        if not page_forecasts:
            break
        collected.extend(page_forecasts)

    logging.info(
        f"[X] get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url}"
    )
    return collected
|
||||
|
||||
|
||||
# ---
|
||||
|
||||
|
||||
async def async_get_forecasts(uid, questions, platform_url, headers, cookies):
    """Fetch forecasts for every question concurrently (at most 5 in flight)."""
    async with aiohttp.ClientSession(headers=headers, cookies=cookies) as session:
        tasks = [
            get_forecasts_on_the_question(uid, qid, platform_url, session)
            for qid in questions
        ]
        results = await aioitertools.asyncio.gather(*tasks, limit=5)
    return dict(zip(questions, results))
|
||||
|
||||
|
||||
async def async_get_resolutions(questions, platform_url, headers, cookies):
    """Fetch resolutions for every question concurrently (at most 5 in flight)."""
    async with aiohttp.ClientSession(headers=headers, cookies=cookies) as session:
        tasks = [
            get_question_resolution(qid, platform_url, session) for qid in questions
        ]
        results = await aioitertools.asyncio.gather(*tasks, limit=5)
    return dict(zip(questions, results))
|
||||
|
||||
|
||||
def request_forecasts(uid, missing_forecasts_qs, platform_url, headers, cookies):
    """Synchronous wrapper around async_get_forecasts."""
    coro = async_get_forecasts(
        uid, missing_forecasts_qs, platform_url, headers, cookies
    )
    return asyncio.run(coro)
|
||||
|
||||
|
||||
def request_resolutions(missing_resolutions_qs, platform_url, headers, cookies):
    """Synchronous wrapper around async_get_resolutions."""
    coro = async_get_resolutions(
        missing_resolutions_qs, platform_url, headers, cookies
    )
    return asyncio.run(coro)
|
172
plotting.py
Normal file
172
plotting.py
Normal file
|
@ -0,0 +1,172 @@
|
|||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
from calibration import calibration_curve
|
||||
|
||||
|
||||
def plotly_calibration(y_true, y_pred, n_bins, strategy="quantile"):
    """Build a plotly calibration plot: mean predicted value vs fraction of positives.

    Error bars are the binomial standard deviation sqrt(p * (1 - p) / N)
    of each bin's fraction of positives. Returns a go.Figure.
    """
    fraction_of_positives, mean_predicted_value, counts = calibration_curve(
        y_true, y_pred, n_bins=n_bins, strategy=strategy
    )
    # Binomial std of the per-bin fraction of positives.
    error_y = np.sqrt((fraction_of_positives) * (1 - fraction_of_positives) / counts)

    fig = go.Figure()

    fig.add_trace(
        go.Scatter(
            x=mean_predicted_value,
            y=fraction_of_positives,
            customdata=counts,  # bin sizes, surfaced in the hover text as N
            mode="markers",
            error_y=dict(
                type="data",
                array=error_y,
                thickness=1.5,
                width=3,
            ),
            hovertemplate="<br>".join(
                [
                    "x: %{x:.3f}",
                    "y: %{y:.3f}",
                    "N: %{customdata}",
                    "<extra></extra>",
                ]
            ),
            showlegend=False,
        )
    )

    # Diagonal y = x: the perfectly-calibrated reference line.
    fig.add_shape(
        type="line",
        x0=0,
        y0=0,
        x1=1,
        y1=1,
        line=dict(
            color="LightSeaGreen",
            width=2,
            dash="dot",
        ),
        opacity=0.5,
    )

    fig.update_layout(
        width=800,
        height=800,
        title="Calibration plot",
        xaxis_title="Mean predicted value",
        yaxis_title="Fraction of positives (± std)",
    )

    fig.update_xaxes(
        range=[-0.05, 1.05],
        constrain="domain",
    )

    fig.update_yaxes(
        range=[-0.05, 1.05],
        constrain="domain",
        scaleanchor="x",  # lock the aspect ratio so the plot stays square
        scaleratio=1,
    )

    return fig
|
||||
|
||||
|
||||
def plotly_calibration_odds(y_true, y_pred, n_bins, strategy="quantile"):
    """Build a calibration plot with both axes on a base-2 log-odds scale.

    Returns a go.Figure; axis ticks are rendered as "a : b" odds labels.
    """
    y_pred = np.clip(y_pred, 0.005, 0.995)  # clipping to avoid undefined odds
    # Keep labels away from exactly 0/1 so the odds transform stays finite.
    y_true = np.clip(y_true, 1e-3, 1 - 1e-3)
    fraction_of_positives, mean_predicted_value, counts = calibration_curve(
        y_true, y_pred, n_bins=n_bins, strategy=strategy
    )
    # Binomial std of the per-bin fraction of positives.
    error_y = np.sqrt((fraction_of_positives) * (1 - fraction_of_positives) / counts)

    fig = go.Figure()

    # Probability → base-2 log-odds.
    transform = lambda x: np.log2(1 / (1 - x) - 1)  # 66.6% → 2^{1}:1 → 1

    # Hover data columns: [bin size, x as an odds string, y as an odds string].
    customdata = np.dstack(
        [
            counts,
            [
                f"{2**x:.1f} : 1" if x > 0 else f"1 : {2**-x:.1f}"
                for x in transform(mean_predicted_value)
            ],
            [
                f"{2**x:.1f} : 1" if x > 0 else f"1 : {2**-x:.1f}"
                for x in transform(fraction_of_positives)
            ],
        ]
    ).squeeze()

    fig.add_trace(
        go.Scatter(
            x=transform(mean_predicted_value),
            y=transform(fraction_of_positives),
            customdata=customdata,
            mode="markers",
            # Asymmetric error bars: the log-odds transform is nonlinear, so
            # ±std in probability space maps to unequal distances here.
            error_y=dict(
                type="data",
                symmetric=False,
                array=transform(fraction_of_positives + error_y)
                - transform(fraction_of_positives),
                arrayminus=transform(fraction_of_positives)
                - transform(fraction_of_positives - error_y),
                thickness=1.5,
                width=3,
            ),
            hovertemplate="<br>".join(
                [
                    "x: %{customdata[1]}",
                    "y: %{customdata[2]}",
                    "N: %{customdata[0]}",
                    "<extra></extra>",
                ]
            ),
            showlegend=False,
        )
    )

    # Diagonal y = x: the perfectly-calibrated reference line.
    fig.add_shape(
        type="line",
        x0=-8,
        y0=-8,
        x1=8,
        y1=8,
        line=dict(
            color="LightSeaGreen",
            width=2,
            dash="dot",
        ),
        opacity=0.5,
    )

    fig.update_layout(
        width=800,
        height=800,
        title="Calibration plot in terms of odds",
        xaxis_title="Mean predicted value",
        yaxis_title="Fraction of positives (± std)",
    )

    fig.update_xaxes(
        range=[-8, 8],
        constrain="domain",
        tickmode="array",
        tickvals=list(range(-10, 10)),
        # Render each log-odds tick as a plain odds ratio, e.g. 2 → "4 : 1".
        ticktext=[
            f"{2**x} : 1" if x > 0 else f"1 : {2**-x}" for x in list(range(-10, 10))
        ],
    )

    fig.update_yaxes(
        range=[-8, 8],
        constrain="domain",
        scaleanchor="x",  # lock the aspect ratio so the plot stays square
        scaleratio=1,
        tickvals=list(range(-10, 10)),
        ticktext=[
            f"{2**x} : 1" if x > 0 else f"1 : {2**-x}" for x in list(range(-10, 10))
        ],
    )

    return fig
|
24
requirements.txt
Normal file
24
requirements.txt
Normal file
|
@ -0,0 +1,24 @@
|
|||
lxml==4.6.1
|
||||
lockfile==0.12.2
|
||||
numpy==1.19.3
|
||||
keyring==21.5.0
|
||||
mypy_extensions==0.4.3
|
||||
pandas==1.1.3
|
||||
typing_extensions==3.7.4.3
|
||||
aiohttp==3.7.4.post0
|
||||
aioitertools==0.7.1
|
||||
beautifulsoup4==4.9.3
|
||||
brotli==1.0.9
|
||||
cryptography==3.4.7
|
||||
Cython==0.29.23
|
||||
docutils==0.17.1
|
||||
importlib_metadata==4.4.0
|
||||
ipaddr==2.2.0
|
||||
ordereddict==1.1
|
||||
plotly==4.14.3
|
||||
protobuf==3.17.1
|
||||
pyOpenSSL==20.0.1
|
||||
streamlit==0.82.0
|
||||
uncurl==0.0.11
|
||||
wincertstore==0.2
|
||||
zipp==3.4.1
|
100
strmlt.py
Normal file
100
strmlt.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
import uncurl
|
||||
from calibration import overconfidence
|
||||
from firebase_requests import get_forecasts, get_resolutions
|
||||
from gjo_requests import get_resolved_questions
|
||||
from plotting import plotly_calibration, plotly_calibration_odds
|
||||
|
||||
if __name__ == "__main__":

    st.title("Learn how calibrated are you?")

    # ---

    # if st.checkbox('I am new! Show me instructions.'):
    #     st.write("""
    #     Hey!
    #     """)

    # ---

    platform = st.selectbox(
        "Which platform are you using?",
        ["Good Judgement Open", "CSET Foretell"],
    )
    platform_url = {
        "Good Judgement Open": "https://www.gjopen.com",
        "CSET Foretell": "https://www.cset-foretell.com",
    }[platform]

    uid = st.number_input("What is your user ID?", min_value=1, value=28899)
    uid = str(uid)

    curl_value = ""

    # Parse headers/cookies out of a browser-copied cURL command.
    curl_command = st.text_area(
        "Ugh... Gimme your cURL info...", value=curl_value.strip()
    )
    curl_content = uncurl.parse_context(curl_command)
    headers, cookies = curl_content.headers, curl_content.cookies

    # ---

    questions = get_resolved_questions(uid, platform_url, headers, cookies)

    st.write(f"{len(questions)} questions you forecasted on have resolved.")

    # ---
    # TODO: Make a progress bar..?

    forecasts = get_forecasts(uid, questions, platform_url, headers, cookies)
    resolutions = get_resolutions(questions, platform_url, headers, cookies)

    # ---

    num_forecasts = sum(len(f) for f in forecasts.values())
    st.write(
        f"On these {len(questions)} questions you've made {num_forecasts} forecasts."
    )

    # One y_true entry (the question's outcome tuple) per forecast made on it.
    flatten = lambda t: [item for sublist in t for item in sublist]
    y_true = flatten(resolutions[q]["y_true"] for q in questions for _ in forecasts[q])
    y_pred = flatten(f["y_pred"] for q in questions for f in forecasts[q])

    # Note that I am "double counting" each prediction.
    if st.checkbox("Drop last"):
        # Drop the final element of every outcome/probability tuple.
        y_true = flatten(
            resolutions[q]["y_true"][:-1] for q in questions for _ in forecasts[q]
        )
        y_pred = flatten(f["y_pred"][:-1] for q in questions for f in forecasts[q])

    y_true, y_pred = np.array(y_true), np.array(y_pred)

    st.write(f"Which gives us {len(y_pred)} datapoints to work with.")

    # ---

    strategy = st.selectbox(
        # Bug fix: the prompt previously read "stranegy".
        "Which binning strategy do you prefer?",
        ["uniform", "quantile"],
    )

    recommended_n_bins = int(np.sqrt(len(y_pred))) if strategy == "quantile" else 20 + 1
    n_bins = st.number_input(
        "How many bins do you want me to display?",
        min_value=1,
        value=recommended_n_bins,
    )

    fig = plotly_calibration(y_true, y_pred, n_bins=n_bins, strategy=strategy)
    st.plotly_chart(fig, use_container_width=True)

    overconf = overconfidence(y_true, y_pred)
    st.write(f"Your over/under- confidence score is {overconf:.2f}.")

    # ---

    fig = plotly_calibration_odds(y_true, y_pred, n_bins=n_bins, strategy=strategy)
    st.plotly_chart(fig, use_container_width=True)
|
Loading…
Reference in New Issue
Block a user