diff --git a/dm_regional_app/charts.py b/dm_regional_app/charts.py index 91b20f8..0c0d1a6 100644 --- a/dm_regional_app/charts.py +++ b/dm_regional_app/charts.py @@ -10,12 +10,15 @@ def prediction_chart(historic_data: PopulationStats, prediction: Prediction, **k # pop start and end dates to visualise reference period reference_start_date = kwargs.pop("reference_start_date") reference_end_date = kwargs.pop("reference_end_date") - + print(prediction.population) # dataframe containing total children in prediction df = prediction.population.unstack().reset_index() + print(df) df.columns = ["from", "date", "forecast"] - df = df[df["from"].apply(lambda x: "NOT_IN_CARE" in x[1]) == False] + print(df) + df = df[df["from"].apply(lambda x: "Not in care" in x) == False] + print(df) df = df[["date", "forecast"]].groupby(by="date").sum().reset_index() df["date"] = pd.to_datetime(df["date"]).dt.date @@ -99,7 +102,7 @@ def prediction_chart(historic_data: PopulationStats, prediction: Prediction, **k ) fig.update_layout( - title="Base forecast", xaxis_title="Date", yaxis_title="Number of children" + title="Forecast", xaxis_title="Date", yaxis_title="Number of children" ) fig.update_yaxes(rangemode="tozero") fig_html = fig.to_html(full_html=False) @@ -134,7 +137,7 @@ def transition_rate_table(data): df["To"] = df["to"] df["From"] = df["from"] df.set_index(["from", "to"], inplace=True) - df = df[df["To"].apply(lambda x: "NOT_IN_CARE" in x[1]) == False] + # df = df[df["To"].apply(lambda x: "Not in care" in x) == False] df = df.round(4) df["From"] = df["From"].mask(df["From"].duplicated(), "") @@ -143,4 +146,6 @@ def transition_rate_table(data): from_col = df.pop("From") df.insert(0, "From", from_col) + df.columns = ["From", "To", "Transition rate"] + return df diff --git a/dm_regional_app/templatetags/table_tags.py b/dm_regional_app/templatetags/table_tags.py index 0dcdf2c..e9ebb6a 100644 --- a/dm_regional_app/templatetags/table_tags.py +++ b/dm_regional_app/templatetags/table_tags.py @@ -6,7 +6,7 @@ def convert_data_frame_to_html_table_headers(df): html = "" for value in df.columns: - html += f'

{value.capitalize()}

' + html += f'

{value}

' html += "" return html @@ -42,9 +42,7 @@ def convert_data_frame_to_html_table_rows(df): row_html = "" for value in row: if isinstance(value, str): - row_html += ( - f'

{value.capitalize()}

' - ) + row_html += f'

{value}

' else: row_html += f'

{value}

' row_html += "" diff --git a/dm_regional_app/utils.py b/dm_regional_app/utils.py index 8f0c49f..d91a36f 100644 --- a/dm_regional_app/utils.py +++ b/dm_regional_app/utils.py @@ -28,38 +28,42 @@ def apply_filters(data: pd.DataFrame, filters: dict): class DateAwareJSONDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): - super().__init__(object_hook=self.parse_dates, *args, **kwargs) - - def parse_dates(self, obj): - for key, value in obj.items(): - if isinstance(value, str) and "date" in key: - try: - obj[key] = datetime.fromisoformat(value).date() - except ValueError: - pass - return obj + super().__init__(object_hook=self.parse_object, *args, **kwargs) def parse_object(self, obj): obj = self.parse_dates(obj) if "__type__" in obj and obj["__type__"] == "pd.Series": - if all(isinstance(i, list) for i in obj["index"]): + if obj.get("is_multiindex", False): index = pd.MultiIndex.from_tuples(obj["index"]) else: index = obj["index"] return pd.Series(obj["data"], index=index) return obj + def parse_dates(self, obj): + for key, value in obj.items(): + if isinstance(value, str) and "date" in key: + try: + obj[key] = datetime.fromisoformat(value).date() + except ValueError: + pass + return obj + class SeriesAwareJSONEncoder(json.JSONEncoder): def default(self, obj): - # Check if the Series has a MultiIndex if isinstance(obj, pd.Series): # Check if the Series has a MultiIndex if isinstance(obj.index, pd.MultiIndex): - index = obj.index.tolist() + index = [list(tup) for tup in obj.index] else: index = obj.index.tolist() - return {"__type__": "pd.Series", "data": obj.tolist(), "index": index} + return { + "__type__": "pd.Series", + "data": obj.tolist(), + "index": index, + "is_multiindex": isinstance(obj.index, pd.MultiIndex), + } if isinstance(obj, date): return obj.isoformat() # Let the base class default method raise the TypeError diff --git a/dm_regional_app/views.py b/dm_regional_app/views.py index f228bde..60a2ea8 100644 --- a/dm_regional_app/views.py +++ b/dm_regional_app/views.py @@ -233,7 +233,9 @@ def adjusted(request): # Call predict function with default dates prediction = predict( - data=historic_data, **session_scenario.prediction_parameters + data=historic_data, + **session_scenario.prediction_parameters, + rate_adjustment=session_scenario.adjusted_rates ) # build chart diff --git a/ssda903/predictor.py b/ssda903/predictor.py index 58ddfde..a4ebaf8 100644 --- a/ssda903/predictor.py +++ b/ssda903/predictor.py @@ -1,5 +1,5 @@ from datetime import date -from typing import Optional +from typing import Iterable, Optional, Union import pandas as pd from dateutil.relativedelta import relativedelta @@ -15,6 +15,7 @@ def predict( reference_end_date: date, prediction_start_date: Optional[date] = None, prediction_end_date: Optional[date] = None, + rate_adjustment: Union[pd.Series, Iterable[pd.Series]] = None, ) -> Prediction: """ Analyses source between start and end, and then predicts the population at prediction_date. @@ -39,6 +40,7 @@ def predict( reference_start_date, reference_end_date ), start_date=prediction_start_date, + rate_adjustment=rate_adjustment, ) prediction_days = (prediction_end_date - prediction_start_date).days prediction = predictor.predict(prediction_days, progress=False)