First commit: better late than never

This commit is contained in:
yagudin 2021-05-31 21:59:24 +03:00
commit a0a8093e12
6 changed files with 552 additions and 0 deletions

38
calibration.py Normal file
View File

@ -0,0 +1,38 @@
import numpy as np
# This function is a sklearn.calibration.calibration_curve modification
def calibration_curve(y_true, y_prob, *, n_bins=5, strategy="uniform"):
y_true = np.array(y_true)
y_prob = np.array(y_prob)
if strategy == "quantile": # Determine bin edges by distribution of data
quantiles = np.linspace(0, 1, n_bins + 1)
bins = np.percentile(y_prob, quantiles * 100)
bins[-1] = bins[-1] + 1e-8
elif strategy == "uniform":
bins = np.linspace(0.0, 1.0 + 1e-8, n_bins + 1)
else:
raise ValueError(
"Invalid entry to 'strategy' input. Strategy "
"must be either 'quantile' or 'uniform'."
)
binids = np.digitize(y_prob, bins) - 1
bin_sums = np.bincount(binids, weights=y_prob, minlength=len(bins))
bin_true = np.bincount(binids, weights=y_true, minlength=len(bins))
bin_total = np.bincount(binids, minlength=len(bins))
nonzero = bin_total != 0
prob_true = bin_true[nonzero] / bin_total[nonzero]
prob_pred = bin_sums[nonzero] / bin_total[nonzero]
return prob_true, prob_pred, bin_total[nonzero]
def overconfidence(y_true, y_pred):
    """Return the over/under-confidence score of the forecasts.

    For each forecast, let x be the probability assigned to the outcome that
    actually happened; the score is
    mean((x - 1) * (x - 0.5)) / mean((x - 0.5) ** 2).
    """
    assigned = y_pred * y_true + (1 - y_pred) * (1 - y_true)
    centered = assigned - 0.5
    return np.mean((assigned - 1) * centered) / np.mean(centered ** 2)

43
firebase_requests.py Normal file
View File

@ -0,0 +1,43 @@
import json

import streamlit as st  # fix: st.secrets is used below but streamlit was never imported
from google.cloud import firestore
from google.oauth2 import service_account  # fix: service_account was never imported

from gjo_requests import request_forecasts, request_resolutions

# Firestore client authenticated with the service-account info stored in
# Streamlit secrets.
firestore_info = json.loads(st.secrets["firestore_info"])
credentials = service_account.Credentials.from_service_account_info(firestore_info)
db = firestore.Client(credentials=credentials, project="gjo-calibration")
def get_forecasts(uid, questions, platform_url, headers, cookies):
    """Return {question_id: forecasts} for `questions`, using Firestore as a cache.

    Forecasts already stored under users/<uid> are reused; the remainder are
    fetched from the platform and written back to Firestore.
    """
    cached = db.collection("users").document(uid).get().to_dict()
    if cached is None:
        cached = {}

    to_fetch = list(set(questions) - set(cached))
    fetched = request_forecasts(uid, to_fetch, platform_url, headers, cookies)

    if fetched:
        if not cached:
            # The user document does not exist yet — create it before updating.
            db.collection("users").add({}, uid)
        db.collection("users").document(uid).update(fetched)

    return {**cached, **fetched}
def get_resolutions(questions, platform_url, headers, cookies):
    """Return {question_id: resolution} for `questions`, using Firestore as a cache."""
    stored = db.collection("questions").document("resolutions").get().to_dict()
    if stored is None:
        stored = {}

    wanted = set(questions)
    cached = {qid: res for qid, res in stored.items() if qid in wanted}

    missing = list(wanted - set(cached))
    fetched = request_resolutions(missing, platform_url, headers, cookies)

    if fetched:
        db.collection("questions").document("resolutions").update(fetched)

    return {**cached, **fetched}

175
gjo_requests.py Normal file
View File

@ -0,0 +1,175 @@
import asyncio
import logging
import re
from itertools import count
import aiohttp
import aioitertools
import requests
import streamlit as st
from bs4 import BeautifulSoup
# NOTE(review): a fresh event loop is created and registered here at import
# time, but the request_* helpers below call asyncio.run(), which manages its
# own loop — presumably this is a workaround for running under Streamlit's
# worker thread; confirm it is still needed.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@st.cache
def get_resolved_questions(uid, platform_url, headers, cookies):
    """Return the ids of all resolved questions the user forecasted on.

    Walks the paginated scores pages for membership `uid` until a page
    yields no question links.
    """
    logging.info(
        f"[ ] get_resolved_questions for uid={uid}, platform_url={platform_url}"
    )
    questions = []  # [question_id]
    for page_num in count(1):
        url = f"{platform_url}/memberships/{uid}/scores/?page={page_num}"
        page = requests.get(url, headers=headers, cookies=cookies).text
        # Raw string fix: "\d" in a plain literal is an invalid escape
        # sequence (SyntaxWarning on modern Python).
        extracted_qs = re.findall(r"/questions/(\d+)", page)
        questions.extend(extracted_qs)
        if not extracted_qs:
            break
    logging.info(
        f"[X] get_resolved_questions for uid={uid}, platform_url={platform_url}"
    )
    return questions
async def get_question_resolution(qid, platform_url, session):
    """Fetch and parse the resolution of question `qid`.

    Returns {"y_true": tuple} with one 0/1 entry per answer option.
    """
    # Bug fix: the log messages referenced `uid`, which is not defined in this
    # function (NameError at runtime); log the question id instead.
    logging.info(
        f"[ ] get_question_resolution for qid={qid}, platform_url={platform_url}"
    )
    url = f"{platform_url}/questions/{qid}"
    async with session.get(url) as resp:
        if resp.status != 200:
            logging.error(
                f"get_question_resolution for qid={qid}, platform_url={platform_url} | "
                f"resp.status == {resp.status}{resp.reason}"
            )
        page = await resp.text()
    soup = BeautifulSoup(page, "html.parser")
    soup = soup.find_all("div", {"id": "prediction-interface-container"})[0]
    binary = soup.find_all("div", {"class": "binary-probability-value"})
    if binary:
        # Binary question: second value div holds the resolved answer.
        y_true = (0, 1) if re.search("Yes", binary[1].text) is None else (1, 0)
    else:
        # Multi-option question: the winning row(s) are marked with an <i> icon.
        tables = soup.find_all("table")
        y_true = tuple(len(tr.findAll("i")) for tr in tables[0].findAll("tr")[1:])
    logging.info(
        f"[X] get_question_resolution for qid={qid}, platform_url={platform_url}"
    )
    return {"y_true": y_true}
def _extract_forecasts_from_page(page):
    """Extract [{"y_pred": (probs, ...), "timestamp": str}, ...] from one HTML page.

    Probabilities come from the "prediction-values" divs; a binary question
    shows a single percentage, which is expanded to (p, 1 - p).
    """
    soup = BeautifulSoup(page, "html.parser")
    soup_predictions = soup.find_all("div", {"class": "prediction-values"})
    predictions = [re.findall(r"\n\s*(\d+)%", p_tag.text) for p_tag in soup_predictions]
    predictions = [tuple(int(prob) / 100 for prob in pred) for pred in predictions]
    predictions = [
        (pred[0], 1 - pred[0]) if len(pred) == 1 else pred for pred in predictions
    ]
    # Search for a line containing "made a forecast", then for the next line
    # containing <span data-localizable-timestamp="..."> and grab its timestamp.
    timestamps = []
    looking_for_a_forecast = True
    for line in page.split("\n"):
        if looking_for_a_forecast:
            hit = re.findall("made a forecast", line)
            if hit:
                looking_for_a_forecast = False
        else:
            hit = re.findall(r'<span data-localizable-timestamp="([^"]+)">', line)
            if hit:
                timestamps.extend(hit)
                looking_for_a_forecast = True
    if len(timestamps) != len(predictions):
        # Bug fix: the message referenced uid/qid/page_num, which are not in
        # scope here (NameError), and the two counts were swapped.
        logging.error(
            f"In _extract_forecasts_from_page "
            f"got different number of predictions ({len(predictions)}) and timestamps ({len(timestamps)})."
        )
    return [
        {"y_pred": pred, "timestamp": timestamp}
        for pred, timestamp in zip(predictions, timestamps)
    ]
async def get_forecasts_on_the_question(uid, qid, platform_url, session):
    """Collect every forecast the user made on one question, page by page."""
    logging.info(
        f"[ ] get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url}"
    )
    forecasts = []  # [{"y_pred": (probs, ...), "timestamp": timestamp}, ...]
    page_num = 0
    while True:
        page_num += 1
        url = f"{platform_url}/questions/{qid}/prediction_sets?membership_id={uid}&page={page_num}"
        async with session.get(url) as resp:
            if resp.status != 200:
                logging.error(
                    f"get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url} | "
                    f"resp.status == {resp.status}{resp.reason}"
                )
            page = await resp.text()
        extracted_forecasts = _extract_forecasts_from_page(page)
        forecasts.extend(extracted_forecasts)
        # An empty page means we ran past the last page of forecasts.
        if not extracted_forecasts:
            break
    logging.info(
        f"[X] get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url}"
    )
    return forecasts
# ---
async def async_get_forecasts(uid, questions, platform_url, headers, cookies):
    """Fetch forecasts for all `questions` concurrently (at most 5 in flight)."""
    async with aiohttp.ClientSession(headers=headers, cookies=cookies) as session:
        tasks = [
            get_forecasts_on_the_question(uid, qid, platform_url, session)
            for qid in questions
        ]
        per_question = await aioitertools.asyncio.gather(*tasks, limit=5)
        return dict(zip(questions, per_question))
async def async_get_resolutions(questions, platform_url, headers, cookies):
    """Fetch resolutions for all `questions` concurrently (at most 5 in flight)."""
    async with aiohttp.ClientSession(headers=headers, cookies=cookies) as session:
        tasks = [
            get_question_resolution(qid, platform_url, session) for qid in questions
        ]
        per_question = await aioitertools.asyncio.gather(*tasks, limit=5)
        return dict(zip(questions, per_question))
def request_forecasts(uid, missing_forecasts_qs, platform_url, headers, cookies):
    """Synchronous wrapper around async_get_forecasts."""
    coro = async_get_forecasts(
        uid, missing_forecasts_qs, platform_url, headers, cookies
    )
    return asyncio.run(coro)
def request_resolutions(missing_resolutions_qs, platform_url, headers, cookies):
    """Synchronous wrapper around async_get_resolutions."""
    coro = async_get_resolutions(
        missing_resolutions_qs, platform_url, headers, cookies
    )
    return asyncio.run(coro)

172
plotting.py Normal file
View File

@ -0,0 +1,172 @@
import numpy as np
import plotly.graph_objects as go
from calibration import calibration_curve
def plotly_calibration(y_true, y_pred, n_bins, strategy="quantile"):
    """Build a Plotly calibration (reliability) plot with binomial error bars."""
    frac_pos, mean_pred, counts = calibration_curve(
        y_true, y_pred, n_bins=n_bins, strategy=strategy
    )
    # Binomial standard deviation of the fraction of positives in each bin.
    std = np.sqrt(frac_pos * (1 - frac_pos) / counts)

    hover = "<br>".join(
        ["x: %{x:.3f}", "y: %{y:.3f}", "N: %{customdata}", "<extra></extra>"]
    )
    scatter = go.Scatter(
        x=mean_pred,
        y=frac_pos,
        customdata=counts,
        mode="markers",
        error_y=dict(type="data", array=std, thickness=1.5, width=3),
        hovertemplate=hover,
        showlegend=False,
    )

    fig = go.Figure()
    fig.add_trace(scatter)

    # Diagonal reference line: perfectly calibrated forecasts lie on it.
    fig.add_shape(
        type="line",
        x0=0,
        y0=0,
        x1=1,
        y1=1,
        line=dict(color="LightSeaGreen", width=2, dash="dot"),
        opacity=0.5,
    )

    fig.update_layout(
        width=800,
        height=800,
        title="Calibration plot",
        xaxis_title="Mean predicted value",
        yaxis_title="Fraction of positives (± std)",
    )
    fig.update_xaxes(range=[-0.05, 1.05], constrain="domain")
    fig.update_yaxes(
        range=[-0.05, 1.05], constrain="domain", scaleanchor="x", scaleratio=1
    )
    return fig
def plotly_calibration_odds(y_true, y_pred, n_bins, strategy="quantile"):
    """Build a Plotly calibration plot with both axes in log2-odds space.

    Same data as plotly_calibration, but predicted values and observed
    frequencies are mapped through log2(p / (1 - p)), so one axis unit is a
    doubling of the odds; ticks are labelled as "a : b" odds ratios.
    """
    y_pred = np.clip(y_pred, 0.005, 0.995)  # clipping to avoid undefined odds
    # Keep y_true strictly inside (0, 1) so transform() below stays finite.
    y_true = np.clip(y_true, 1e-3, 1 - 1e-3)
    fraction_of_positives, mean_predicted_value, counts = calibration_curve(
        y_true, y_pred, n_bins=n_bins, strategy=strategy
    )
    # Binomial standard deviation of the per-bin fraction of positives.
    error_y = np.sqrt((fraction_of_positives) * (1 - fraction_of_positives) / counts)
    fig = go.Figure()
    # log2 odds: p -> log2(p / (1 - p)).
    transform = lambda x: np.log2(1 / (1 - x) - 1)  # 66.6% → 2^{1}:1 → 1
    # customdata columns: [bin count, x as odds string, y as odds string].
    customdata = np.dstack(
        [
            counts,
            [
                f"{2**x:.1f} : 1" if x > 0 else f"1 : {2**-x:.1f}"
                for x in transform(mean_predicted_value)
            ],
            [
                f"{2**x:.1f} : 1" if x > 0 else f"1 : {2**-x:.1f}"
                for x in transform(fraction_of_positives)
            ],
        ]
    ).squeeze()
    fig.add_trace(
        go.Scatter(
            x=transform(mean_predicted_value),
            y=transform(fraction_of_positives),
            customdata=customdata,
            mode="markers",
            # Error bars become asymmetric after the nonlinear odds transform,
            # so the +/- arms are computed separately.
            error_y=dict(
                type="data",
                symmetric=False,
                array=transform(fraction_of_positives + error_y)
                - transform(fraction_of_positives),
                arrayminus=transform(fraction_of_positives)
                - transform(fraction_of_positives - error_y),
                thickness=1.5,
                width=3,
            ),
            hovertemplate="<br>".join(
                [
                    "x: %{customdata[1]}",
                    "y: %{customdata[2]}",
                    "N: %{customdata[0]}",
                    "<extra></extra>",
                ]
            ),
            showlegend=False,
        )
    )
    # Diagonal reference line: perfect calibration in odds space.
    fig.add_shape(
        type="line",
        x0=-8,
        y0=-8,
        x1=8,
        y1=8,
        line=dict(
            color="LightSeaGreen",
            width=2,
            dash="dot",
        ),
        opacity=0.5,
    )
    fig.update_layout(
        width=800,
        height=800,
        title="Calibration plot in terms of odds",
        xaxis_title="Mean predicted value",
        yaxis_title="Fraction of positives (± std)",
    )
    # Ticks labelled as odds ratios, e.g. 2 -> "4 : 1", -1 -> "1 : 2".
    fig.update_xaxes(
        range=[-8, 8],
        constrain="domain",
        tickmode="array",
        tickvals=list(range(-10, 10)),
        ticktext=[
            f"{2**x} : 1" if x > 0 else f"1 : {2**-x}" for x in list(range(-10, 10))
        ],
    )
    fig.update_yaxes(
        range=[-8, 8],
        constrain="domain",
        scaleanchor="x",
        scaleratio=1,
        tickvals=list(range(-10, 10)),
        ticktext=[
            f"{2**x} : 1" if x > 0 else f"1 : {2**-x}" for x in list(range(-10, 10))
        ],
    )
    return fig

24
requirements.txt Normal file
View File

@ -0,0 +1,24 @@
lxml==4.6.1
lockfile==0.12.2
numpy==1.19.3
keyring==21.5.0
mypy_extensions==0.4.3
pandas==1.1.3
typing_extensions==3.7.4.3
aiohttp==3.7.4.post0
aioitertools==0.7.1
beautifulsoup4==4.9.3
brotli==1.0.9
cryptography==3.4.7
Cython==0.29.23
docutils==0.17.1
importlib_metadata==4.4.0
ipaddr==2.2.0
ordereddict==1.1
plotly==4.14.3
protobuf==3.17.1
pyOpenSSL==20.0.1
streamlit==0.82.0
uncurl==0.0.11
wincertstore==0.2
zipp==3.4.1

100
strmlt.py Normal file
View File

@ -0,0 +1,100 @@
import numpy as np
import pandas as pd
import streamlit as st
import uncurl
from calibration import overconfidence
from firebase_requests import get_forecasts, get_resolutions
from gjo_requests import get_resolved_questions
from plotting import plotly_calibration, plotly_calibration_odds
if __name__ == "__main__":
    st.title("Learn how calibrated are you?")

    # ---

    # if st.checkbox('I am new! Show me instructions.'):
    #     st.write("""
    #     Hey!
    #     """)

    # ---

    platform = st.selectbox(
        "Which platform are you using?",
        ["Good Judgement Open", "CSET Foretell"],
    )
    platform_url = {
        "Good Judgement Open": "https://www.gjopen.com",
        "CSET Foretell": "https://www.cset-foretell.com",
    }[platform]

    uid = st.number_input("What is your user ID?", min_value=1, value=28899)
    uid = str(uid)

    curl_value = ""
    curl_command = st.text_area(
        "Ugh... Gimme your cURL info...", value=curl_value.strip()
    )
    # Robustness fix: uncurl cannot parse an empty string — wait for the user
    # to paste something instead of crashing on the first render.
    if not curl_command:
        st.stop()
    curl_content = uncurl.parse_context(curl_command)
    headers, cookies = curl_content.headers, curl_content.cookies

    # ---

    questions = get_resolved_questions(uid, platform_url, headers, cookies)
    st.write(f"{len(questions)} questions you forecasted on have resolved.")

    # ---

    # TODO: Make a progress bar..?
    forecasts = get_forecasts(uid, questions, platform_url, headers, cookies)
    resolutions = get_resolutions(questions, platform_url, headers, cookies)

    # ---

    num_forecasts = sum(len(f) for f in forecasts.values())
    st.write(
        f"On these {len(questions)} questions you've made {num_forecasts} forecasts."
    )

    flatten = lambda t: [item for sublist in t for item in sublist]
    # One (y_true, y_pred) pair per answer option per forecast; note that this
    # "double counts" each binary prediction (both options are included).
    y_true = flatten(resolutions[q]["y_true"] for q in questions for _ in forecasts[q])
    y_pred = flatten(f["y_pred"] for q in questions for f in forecasts[q])

    if st.checkbox("Drop last"):
        # Drop the last answer option of every question/forecast.
        y_true = flatten(
            resolutions[q]["y_true"][:-1] for q in questions for _ in forecasts[q]
        )
        y_pred = flatten(f["y_pred"][:-1] for q in questions for f in forecasts[q])

    y_true, y_pred = np.array(y_true), np.array(y_pred)
    st.write(f"Which gives us {len(y_pred)} datapoints to work with.")

    # ---

    strategy = st.selectbox(
        # Typo fix: "stranegy" -> "strategy" in the user-facing prompt.
        "Which binning strategy do you prefer?",
        ["uniform", "quantile"],
    )
    recommended_n_bins = int(np.sqrt(len(y_pred))) if strategy == "quantile" else 20 + 1
    n_bins = st.number_input(
        "How many bins do you want me to display?",
        min_value=1,
        value=recommended_n_bins,
    )

    fig = plotly_calibration(y_true, y_pred, n_bins=n_bins, strategy=strategy)
    st.plotly_chart(fig, use_container_width=True)

    overconf = overconfidence(y_true, y_pred)
    st.write(f"Your over/under- confidence score is {overconf:.2f}.")

    # ---

    fig = plotly_calibration_odds(y_true, y_pred, n_bins=n_bins, strategy=strategy)
    st.plotly_chart(fig, use_container_width=True)