diff --git a/dm_regional_app/charts.py b/dm_regional_app/charts.py
index 91b20f8..0c0d1a6 100644
--- a/dm_regional_app/charts.py
+++ b/dm_regional_app/charts.py
@@ -10,12 +10,15 @@ def prediction_chart(historic_data: PopulationStats, prediction: Prediction, **k
# pop start and end dates to visualise reference period
reference_start_date = kwargs.pop("reference_start_date")
reference_end_date = kwargs.pop("reference_end_date")
-
+ print(prediction.population)
# dataframe containing total children in prediction
df = prediction.population.unstack().reset_index()
+ print(df)
df.columns = ["from", "date", "forecast"]
- df = df[df["from"].apply(lambda x: "NOT_IN_CARE" in x[1]) == False]
+ print(df)
+ df = df[df["from"].apply(lambda x: "Not in care" in x) == False]
+ print(df)
df = df[["date", "forecast"]].groupby(by="date").sum().reset_index()
df["date"] = pd.to_datetime(df["date"]).dt.date
@@ -99,7 +102,7 @@ def prediction_chart(historic_data: PopulationStats, prediction: Prediction, **k
)
fig.update_layout(
- title="Base forecast", xaxis_title="Date", yaxis_title="Number of children"
+ title="Forecast", xaxis_title="Date", yaxis_title="Number of children"
)
fig.update_yaxes(rangemode="tozero")
fig_html = fig.to_html(full_html=False)
@@ -134,7 +137,7 @@ def transition_rate_table(data):
df["To"] = df["to"]
df["From"] = df["from"]
df.set_index(["from", "to"], inplace=True)
- df = df[df["To"].apply(lambda x: "NOT_IN_CARE" in x[1]) == False]
+ # df = df[df["To"].apply(lambda x: "Not in care" in x) == False]
df = df.round(4)
df["From"] = df["From"].mask(df["From"].duplicated(), "")
@@ -143,4 +146,6 @@ def transition_rate_table(data):
from_col = df.pop("From")
df.insert(0, "From", from_col)
+ df.columns = ["From", "To", "Transition rate"]
+
return df
diff --git a/dm_regional_app/templatetags/table_tags.py b/dm_regional_app/templatetags/table_tags.py
index 0dcdf2c..e9ebb6a 100644
--- a/dm_regional_app/templatetags/table_tags.py
+++ b/dm_regional_app/templatetags/table_tags.py
@@ -6,7 +6,7 @@
def convert_data_frame_to_html_table_headers(df):
html = "
"
for value in df.columns:
- html += f'{value.capitalize()} | '
+ html += f'{value} | '
html += "
"
return html
@@ -42,9 +42,7 @@ def convert_data_frame_to_html_table_rows(df):
row_html = ""
for value in row:
if isinstance(value, str):
- row_html += (
- f'{value.capitalize()} | '
- )
+ row_html += f'{value} | '
else:
row_html += f'{value} | '
row_html += "
"
diff --git a/dm_regional_app/utils.py b/dm_regional_app/utils.py
index 8f0c49f..d91a36f 100644
--- a/dm_regional_app/utils.py
+++ b/dm_regional_app/utils.py
@@ -28,38 +28,42 @@ def apply_filters(data: pd.DataFrame, filters: dict):
class DateAwareJSONDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
- super().__init__(object_hook=self.parse_dates, *args, **kwargs)
-
- def parse_dates(self, obj):
- for key, value in obj.items():
- if isinstance(value, str) and "date" in key:
- try:
- obj[key] = datetime.fromisoformat(value).date()
- except ValueError:
- pass
- return obj
+ super().__init__(object_hook=self.parse_object, *args, **kwargs)
def parse_object(self, obj):
obj = self.parse_dates(obj)
if "__type__" in obj and obj["__type__"] == "pd.Series":
- if all(isinstance(i, list) for i in obj["index"]):
+ if obj.get("is_multiindex", False):
index = pd.MultiIndex.from_tuples(obj["index"])
else:
index = obj["index"]
return pd.Series(obj["data"], index=index)
return obj
+ def parse_dates(self, obj):
+ for key, value in obj.items():
+ if isinstance(value, str) and "date" in key:
+ try:
+ obj[key] = datetime.fromisoformat(value).date()
+ except ValueError:
+ pass
+ return obj
+
class SeriesAwareJSONEncoder(json.JSONEncoder):
def default(self, obj):
- # Check if the Series has a MultiIndex
if isinstance(obj, pd.Series):
# Check if the Series has a MultiIndex
if isinstance(obj.index, pd.MultiIndex):
- index = obj.index.tolist()
+ index = [list(tup) for tup in obj.index]
else:
index = obj.index.tolist()
- return {"__type__": "pd.Series", "data": obj.tolist(), "index": index}
+ return {
+ "__type__": "pd.Series",
+ "data": obj.tolist(),
+ "index": index,
+ "is_multiindex": isinstance(obj.index, pd.MultiIndex),
+ }
if isinstance(obj, date):
return obj.isoformat()
# Let the base class default method raise the TypeError
diff --git a/dm_regional_app/views.py b/dm_regional_app/views.py
index f228bde..60a2ea8 100644
--- a/dm_regional_app/views.py
+++ b/dm_regional_app/views.py
@@ -233,7 +233,9 @@ def adjusted(request):
# Call predict function with default dates
prediction = predict(
- data=historic_data, **session_scenario.prediction_parameters
+ data=historic_data,
+ **session_scenario.prediction_parameters,
+ rate_adjustment=session_scenario.adjusted_rates
)
# build chart
diff --git a/ssda903/predictor.py b/ssda903/predictor.py
index 58ddfde..a4ebaf8 100644
--- a/ssda903/predictor.py
+++ b/ssda903/predictor.py
@@ -1,5 +1,5 @@
from datetime import date
-from typing import Optional
+from typing import Iterable, Optional, Union
import pandas as pd
from dateutil.relativedelta import relativedelta
@@ -15,6 +15,7 @@ def predict(
reference_end_date: date,
prediction_start_date: Optional[date] = None,
prediction_end_date: Optional[date] = None,
+ rate_adjustment: Union[pd.Series, Iterable[pd.Series]] = None,
) -> Prediction:
"""
Analyses source between start and end, and then predicts the population at prediction_date.
@@ -39,6 +40,7 @@ def predict(
reference_start_date, reference_end_date
),
start_date=prediction_start_date,
+ rate_adjustment=rate_adjustment,
)
prediction_days = (prediction_end_date - prediction_start_date).days
prediction = predictor.predict(prediction_days, progress=False)