commit a0a8093e12d3f7e8a16ba480da479bed5ab2c037
Author: yagudin
Date:   Mon May 31 21:59:24 2021 +0300

    First commit: better late than never

diff --git a/calibration.py b/calibration.py
new file mode 100644
index 0000000..d321ee3
--- /dev/null
+++ b/calibration.py
@@ -0,0 +1,38 @@
+import numpy as np
+
+
+# This function is a sklearn.calibration.calibration_curve modification
+def calibration_curve(y_true, y_prob, *, n_bins=5, strategy="uniform"):
+    y_true = np.array(y_true)
+    y_prob = np.array(y_prob)
+
+    if strategy == "quantile":  # Determine bin edges by distribution of data
+        quantiles = np.linspace(0, 1, n_bins + 1)
+        bins = np.percentile(y_prob, quantiles * 100)
+        bins[-1] = bins[-1] + 1e-8  # make the last bin right-inclusive
+
+    elif strategy == "uniform":
+        bins = np.linspace(0.0, 1.0 + 1e-8, n_bins + 1)
+
+    else:
+        raise ValueError(
+            "Invalid entry to 'strategy' input. Strategy "
+            "must be either 'quantile' or 'uniform'."
+        )
+
+    binids = np.digitize(y_prob, bins) - 1
+
+    bin_sums = np.bincount(binids, weights=y_prob, minlength=len(bins))
+    bin_true = np.bincount(binids, weights=y_true, minlength=len(bins))
+    bin_total = np.bincount(binids, minlength=len(bins))
+
+    nonzero = bin_total != 0  # drop empty bins; also returns per-bin counts
+    prob_true = bin_true[nonzero] / bin_total[nonzero]
+    prob_pred = bin_sums[nonzero] / bin_total[nonzero]
+
+    return prob_true, prob_pred, bin_total[nonzero]
+
+
+def overconfidence(y_true, y_pred):
+    x = y_pred * y_true + (1 - y_pred) * (1 - y_true)  # prob assigned to the realized outcome
+    return np.mean((x - 1) * (x - 0.5)) / np.mean((x - 0.5) * (x - 0.5))
diff --git a/firebase_requests.py b/firebase_requests.py
new file mode 100644
index 0000000..681d73e
--- /dev/null
+++ b/firebase_requests.py
@@ -0,0 +1,45 @@
+import json
+
+import streamlit as st
+from gjo_requests import request_forecasts, request_resolutions
+from google.cloud import firestore
+from google.oauth2 import service_account
+
+firestore_info = json.loads(st.secrets["firestore_info"])
+credentials = service_account.Credentials.from_service_account_info(firestore_info)
+db = firestore.Client(credentials=credentials, project="gjo-calibration")
+
+
+def get_forecasts(uid, questions, platform_url, headers, cookies):
+    # Firestore-cached fetch: only questions missing from the user's document
+    # are requested from the platform, then written back to the cache.
+    db_forecasts = db.collection("users").document(uid).get().to_dict()
+    db_forecasts = dict() if db_forecasts is None else db_forecasts
+
+    missing_forecasts_qs = list(set(questions) - set(db_forecasts))
+    missing_forecasts = request_forecasts(
+        uid, missing_forecasts_qs, platform_url, headers, cookies
+    )
+
+    if missing_forecasts:
+        if not db_forecasts:
+            db.collection("users").add({}, uid)  # create the doc before updating it
+        db.collection("users").document(uid).update(missing_forecasts)
+
+    return {**db_forecasts, **missing_forecasts}
+
+
+def get_resolutions(questions, platform_url, headers, cookies):
+    # Same caching scheme as get_forecasts, but all resolutions share one doc.
+    db_resolutions = db.collection("questions").document("resolutions").get().to_dict()
+    db_resolutions = dict() if db_resolutions is None else db_resolutions
+    relevant_resolutions = {
+        key: value for key, value in db_resolutions.items() if key in set(questions)
+    }
+
+    missing_resolutions_qs = list(set(questions) - set(relevant_resolutions))
+    missing_resolutions = request_resolutions(
+        missing_resolutions_qs, platform_url, headers, cookies
+    )
+
+    if missing_resolutions:
+        db.collection("questions").document("resolutions").update(missing_resolutions)
+
+    return {**relevant_resolutions, **missing_resolutions}
diff --git a/gjo_requests.py b/gjo_requests.py
new file mode 100644
index 0000000..04d189e
--- /dev/null
+++ b/gjo_requests.py
@@ -0,0 +1,175 @@
+import asyncio
+import logging
+import re
+from itertools import count
+
+import aiohttp
+import aioitertools
+import requests
+import streamlit as st
+from bs4 import BeautifulSoup
+
+loop = asyncio.new_event_loop()
+asyncio.set_event_loop(loop)
+
+
+@st.cache
+def get_resolved_questions(uid, platform_url, headers, cookies):
+    """Scrape the paginated scores page for all resolved question ids of a user."""
+    logging.info(
+        f"[ ] get_resolved_questions for uid={uid}, platform_url={platform_url}"
+    )
+
+    questions = []  # [question_id]
+
+    for page_num in count(1):
+        url = f"{platform_url}/memberships/{uid}/scores/?page={page_num}"
+        page = requests.get(url, headers=headers, cookies=cookies).text
+
+        extracted_qs = re.findall(r"/questions/(\d+)", page)
+        questions.extend(extracted_qs)
+
+        if not extracted_qs:  # an empty page means we ran out of results
+            break
+
+    logging.info(
+        f"[X] get_resolved_questions for uid={uid}, platform_url={platform_url}"
+    )
+
+    return questions
+
+
+async def get_question_resolution(qid, platform_url, session):
+    """Scrape a question page and return {"y_true": tuple} — 1 marks the resolved option."""
+    logging.info(
+        f"[ ] get_question_resolution for qid={qid}, platform_url={platform_url}"
+    )
+
+    url = f"{platform_url}/questions/{qid}"
+
+    async with session.get(url) as resp:
+        if resp.status != 200:
+            logging.error(
+                f"get_question_resolution for qid={qid}, platform_url={platform_url} | "
+                f"resp.status == {resp.status} → {resp.reason}"
+            )
+
+        page = await resp.text()
+
+    soup = BeautifulSoup(page, "html.parser")
+    soup = soup.find_all("div", {"id": "prediction-interface-container"})[0]
+
+    binary = soup.find_all("div", {"class": "binary-probability-value"})
+    if binary:
+        # Binary question: second probability element carries the "Yes" label.
+        y_true = (0, 1) if re.search("Yes", binary[1].text) is None else (1, 0)
+    else:
+        # Multi-option question: the winning row carries an <i> checkmark.
+        tables = soup.find_all("table")
+        y_true = tuple(len(tr.findAll("i")) for tr in tables[0].findAll("tr")[1:])
+
+    logging.info(
+        f"[X] get_question_resolution for qid={qid}, platform_url={platform_url}"
+    )
+    return {"y_true": y_true}
+
+
+def _extract_forecasts_from_page(page):
+    """Parse one prediction-sets HTML page into [{"y_pred": probs, "timestamp": ts}]."""
+    soup = BeautifulSoup(page, "html.parser")
+    soup_predictions = soup.find_all("div", {"class": "prediction-values"})
+    predictions = [re.findall(r"\n\s*(\d+)%", p_tag.text) for p_tag in soup_predictions]
+    predictions = [tuple(int(prob) / 100 for prob in pred) for pred in predictions]
+    predictions = [
+        (pred[0], 1 - pred[0]) if len(pred) == 1 else pred for pred in predictions
+    ]
+
+    # Search for a line containing "made a forecast",
+    # then search for the next line containing a timestamp
+    # and grab the timestamp from it.
+    timestamps = []
+    looking_for_a_forecast = True
+    for line in page.split("\n"):
+        if looking_for_a_forecast:
+            hit = re.findall("made a forecast", line)
+            if hit:
+                looking_for_a_forecast = False
+
+        else:
+            hit = re.findall('', line)  # NOTE(review): pattern looks truncated — it matches every line and captures '' — recover the original timestamp regex
+            if hit:
+                timestamps.extend(hit)
+                looking_for_a_forecast = True
+
+    if len(timestamps) != len(predictions):
+        logging.error(
+            "In _extract_forecasts_from_page got different numbers of "
+            f"predictions ({len(predictions)}) and timestamps ({len(timestamps)})."
+        )
+
+    return [
+        {"y_pred": pred, "timestamp": timestamp}
+        for pred, timestamp in zip(predictions, timestamps)
+    ]
+
+
+async def get_forecasts_on_the_question(uid, qid, platform_url, session):
+    """Scrape all paginated forecasts of one user on one question."""
+    logging.info(
+        f"[ ] get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url}"
+    )
+
+    forecasts = []  # [{"y_pred": (probs, ...), "timestamp": timestamp}, ...]
+
+    for page_num in count(1):
+        url = f"{platform_url}/questions/{qid}/prediction_sets?membership_id={uid}&page={page_num}"
+
+        async with session.get(url) as resp:
+            if resp.status != 200:
+                logging.error(
+                    f"get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url} | "
+                    f"resp.status == {resp.status} → {resp.reason}"
+                )
+
+            page = await resp.text()
+
+        extracted_forecasts = _extract_forecasts_from_page(page)
+        forecasts.extend(extracted_forecasts)
+
+        if not extracted_forecasts:
+            break
+
+    logging.info(
+        f"[X] get_forecasts_on_the_question for uid={uid}, qid={qid}, platform_url={platform_url}"
+    )
+    return forecasts
+
+
+# ---
+
+
+async def async_get_forecasts(uid, questions, platform_url, headers, cookies):
+    # limit=5 caps concurrent requests to be gentle with the platform.
+    async with aiohttp.ClientSession(headers=headers, cookies=cookies) as session:
+        forecasts_list = await aioitertools.asyncio.gather(
+            *[
+                get_forecasts_on_the_question(uid, q, platform_url, session)
+                for q in questions
+            ],
+            limit=5,
+        )
+        return {q: forecasts_list[i] for i, q in enumerate(questions)}
+
+
+async def async_get_resolutions(questions, platform_url, headers, cookies):
+    async with aiohttp.ClientSession(headers=headers, cookies=cookies) as session:
+        resolutions_list = await aioitertools.asyncio.gather(
+            *[get_question_resolution(q, platform_url, session) for q in questions],
+            limit=5,
+        )
+        return {q: resolutions_list[i] for i, q in enumerate(questions)}
+
+
+def request_forecasts(uid, missing_forecasts_qs, platform_url, headers, cookies):
+    return asyncio.run(
+        async_get_forecasts(uid, missing_forecasts_qs, platform_url, headers, cookies)
+    )
+
+
+def request_resolutions(missing_resolutions_qs, platform_url, headers, cookies):
+    return asyncio.run(
+        async_get_resolutions(missing_resolutions_qs, platform_url, headers, cookies)
+    )
diff --git a/plotting.py b/plotting.py
new file mode 100644
index 0000000..cff47cf
--- /dev/null
+++ b/plotting.py
@@ -0,0 +1,172 @@
+import numpy as np
+import plotly.graph_objects as go
+from calibration import calibration_curve
+
+
+def plotly_calibration(y_true, y_pred, n_bins, strategy="quantile"):
+    """Calibration scatter with binomial std error bars against the y=x diagonal."""
+    fraction_of_positives, mean_predicted_value, counts = calibration_curve(
+        y_true, y_pred, n_bins=n_bins, strategy=strategy
+    )
+    error_y = np.sqrt((fraction_of_positives) * (1 - fraction_of_positives) / counts)
+
+    fig = go.Figure()
+
+    fig.add_trace(
+        go.Scatter(
+            x=mean_predicted_value,
+            y=fraction_of_positives,
+            customdata=counts,
+            mode="markers",
+            error_y=dict(
+                type="data",
+                array=error_y,
+                thickness=1.5,
+                width=3,
+            ),
+            hovertemplate="<br>".join(
+                [
+                    "x: %{x:.3f}",
+                    "y: %{y:.3f}",
+                    "N: %{customdata}",
+                    "<extra></extra>",
+                ]
+            ),
+            showlegend=False,
+        )
+    )
+
+    fig.add_shape(
+        type="line",
+        x0=0,
+        y0=0,
+        x1=1,
+        y1=1,
+        line=dict(
+            color="LightSeaGreen",
+            width=2,
+            dash="dot",
+        ),
+        opacity=0.5,
+    )
+
+    fig.update_layout(
+        width=800,
+        height=800,
+        title="Calibration plot",
+        xaxis_title="Mean predicted value",
+        yaxis_title="Fraction of positives (± std)",
+    )
+
+    fig.update_xaxes(
+        range=[-0.05, 1.05],
+        constrain="domain",
+    )
+
+    fig.update_yaxes(
+        range=[-0.05, 1.05],
+        constrain="domain",
+        scaleanchor="x",
+        scaleratio=1,
+    )
+
+    return fig
+
+
+def plotly_calibration_odds(y_true, y_pred, n_bins, strategy="quantile"):
+    """Same calibration plot but on a log2-odds scale with "a : b" tick labels."""
+    y_pred = np.clip(y_pred, 0.005, 0.995)  # clipping to avoid undefined odds
+    y_true = np.clip(y_true, 1e-3, 1 - 1e-3)
+    fraction_of_positives, mean_predicted_value, counts = calibration_curve(
+        y_true, y_pred, n_bins=n_bins, strategy=strategy
+    )
+    error_y = np.sqrt((fraction_of_positives) * (1 - fraction_of_positives) / counts)
+
+    fig = go.Figure()
+
+    transform = lambda x: np.log2(1 / (1 - x) - 1)  # 66.6% → 2^{1}:1 → 1
+
+    customdata = np.dstack(
+        [
+            counts,
+            [
+                f"{2**x:.1f} : 1" if x > 0 else f"1 : {2**-x:.1f}"
+                for x in transform(mean_predicted_value)
+            ],
+            [
+                f"{2**x:.1f} : 1" if x > 0 else f"1 : {2**-x:.1f}"
+                for x in transform(fraction_of_positives)
+            ],
+        ]
+    ).squeeze()
+
+    fig.add_trace(
+        go.Scatter(
+            x=transform(mean_predicted_value),
+            y=transform(fraction_of_positives),
+            customdata=customdata,
+            mode="markers",
+            error_y=dict(
+                type="data",
+                symmetric=False,
+                array=transform(fraction_of_positives + error_y)
+                - transform(fraction_of_positives),
+                arrayminus=transform(fraction_of_positives)
+                - transform(fraction_of_positives - error_y),
+                thickness=1.5,
+                width=3,
+            ),
+            hovertemplate="<br>".join(
+                [
+                    "x: %{customdata[1]}",
+                    "y: %{customdata[2]}",
+                    "N: %{customdata[0]}",
+                    "<extra></extra>",
+                ]
+            ),
+            showlegend=False,
+        )
+    )
+
+    fig.add_shape(
+        type="line",
+        x0=-8,
+        y0=-8,
+        x1=8,
+        y1=8,
+        line=dict(
+            color="LightSeaGreen",
+            width=2,
+            dash="dot",
+        ),
+        opacity=0.5,
+    )
+
+    fig.update_layout(
+        width=800,
+        height=800,
+        title="Calibration plot in terms of odds",
+        xaxis_title="Mean predicted value",
+        yaxis_title="Fraction of positives (± std)",
+    )
+
+    fig.update_xaxes(
+        range=[-8, 8],
+        constrain="domain",
+        tickmode="array",
+        tickvals=list(range(-10, 10)),
+        ticktext=[
+            f"{2**x} : 1" if x > 0 else f"1 : {2**-x}" for x in list(range(-10, 10))
+        ],
+    )
+
+    fig.update_yaxes(
+        range=[-8, 8],
+        constrain="domain",
+        scaleanchor="x",
+        scaleratio=1,
+        tickvals=list(range(-10, 10)),
+        ticktext=[
+            f"{2**x} : 1" if x > 0 else f"1 : {2**-x}" for x in list(range(-10, 10))
+        ],
+    )
+
+    return fig
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5e57ff5
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,24 @@
+lxml==4.6.1
+lockfile==0.12.2
+numpy==1.19.3
+keyring==21.5.0
+mypy_extensions==0.4.3
+pandas==1.1.3
+typing_extensions==3.7.4.3
+aiohttp==3.7.4.post0
+aioitertools==0.7.1
+beautifulsoup4==4.9.3
+brotli==1.0.9
+cryptography==3.4.7
+Cython==0.29.23
+docutils==0.17.1
+importlib_metadata==4.4.0
+ipaddr==2.2.0
+ordereddict==1.1
+plotly==4.14.3
+protobuf==3.17.1
+pyOpenSSL==20.0.1
+streamlit==0.82.0
+uncurl==0.0.11
+wincertstore==0.2
+zipp==3.4.1
diff --git a/strmlt.py b/strmlt.py
new file mode 100644
index 0000000..7c7320d
--- /dev/null
+++ b/strmlt.py
@@ -0,0 +1,100 @@
+import numpy as np
+import pandas as pd
+import streamlit as st
+import uncurl
+from calibration import overconfidence
+from firebase_requests import get_forecasts, get_resolutions
+from gjo_requests import get_resolved_questions
+from plotting import plotly_calibration, plotly_calibration_odds
+
+if __name__ == "__main__":
+
+    st.title("Learn how calibrated are you?")
+
+    # ---
+
+    # if st.checkbox('I am new! Show me instructions.'):
+    #     st.write("""
+    #         Hey!
+    #     """)
+
+    # ---
+
+    platform = st.selectbox(
+        "Which platform are you using?",
+        ["Good Judgement Open", "CSET Foretell"],
+    )
+    platform_url = {
+        "Good Judgement Open": "https://www.gjopen.com",
+        "CSET Foretell": "https://www.cset-foretell.com",
+    }[platform]
+
+    uid = st.number_input("What is your user ID?", min_value=1, value=28899)
+    uid = str(uid)
+
+    curl_value = ""
+
+    curl_command = st.text_area(
+        "Ugh... Gimme your cURL info...", value=curl_value.strip()
+    )
+    curl_content = uncurl.parse_context(curl_command)
+    headers, cookies = curl_content.headers, curl_content.cookies
+
+    # ---
+
+    questions = get_resolved_questions(uid, platform_url, headers, cookies)
+
+    st.write(f"{len(questions)} questions you forecasted on have resolved.")
+
+    # ---
+    # TODO: Make a progress bar..?
+
+    forecasts = get_forecasts(uid, questions, platform_url, headers, cookies)
+    resolutions = get_resolutions(questions, platform_url, headers, cookies)
+
+    # ---
+
+    num_forecasts = sum(len(f) for f in forecasts.values())
+    st.write(
+        f"On these {len(questions)} questions you've made {num_forecasts} forecasts."
+    )
+
+    flatten = lambda t: [item for sublist in t for item in sublist]
+    y_true = flatten(resolutions[q]["y_true"] for q in questions for _ in forecasts[q])
+    y_pred = flatten(f["y_pred"] for q in questions for f in forecasts[q])
+
+    # Note that I am "double counting" each prediction.
+    if st.checkbox("Drop last"):
+        y_true = flatten(
+            resolutions[q]["y_true"][:-1] for q in questions for _ in forecasts[q]
+        )
+        y_pred = flatten(f["y_pred"][:-1] for q in questions for f in forecasts[q])
+
+    y_true, y_pred = np.array(y_true), np.array(y_pred)
+
+    st.write(f"Which gives us {len(y_pred)} datapoints to work with.")
+
+    # ---
+
+    strategy = st.selectbox(
+        "Which binning strategy do you prefer?",
+        ["uniform", "quantile"],
+    )
+
+    recommended_n_bins = int(np.sqrt(len(y_pred))) if strategy == "quantile" else 20 + 1  # NOTE(review): else-branch is 21 — was "int(...) + 1" intended?
+    n_bins = st.number_input(
+        "How many bins do you want me to display?",
+        min_value=1,
+        value=recommended_n_bins,
+    )
+
+    fig = plotly_calibration(y_true, y_pred, n_bins=n_bins, strategy=strategy)
+    st.plotly_chart(fig, use_container_width=True)
+
+    overconf = overconfidence(y_true, y_pred)
+    st.write(f"Your over/under- confidence score is {overconf:.2f}.")
+
+    # ---
+
+    fig = plotly_calibration_odds(y_true, y_pred, n_bins=n_bins, strategy=strategy)
+    st.plotly_chart(fig, use_container_width=True)