diff --git a/strmlt.py b/strmlt.py index 26bc6d7..a3851db 100644 --- a/strmlt.py +++ b/strmlt.py @@ -118,9 +118,11 @@ if __name__ == "__main__": ) y_pred = flatten(f["y_pred"][:-1] for q in questions for f in forecasts[q]) - y_true, y_pred = np.array(y_true), np.array(y_pred) + y_true, y_pred = np.array(y_true), np.array(y_pred) - st.write(f"Which gives us {len(y_pred)} datapoints to work with.") + order = np.arange(len(y_true)) + np.random.default_rng(0).shuffle(order) + y_true, y_pred = y_true[order], y_pred[order] # ---