Skip to content

Commit

Permalink
bug fix for rate tables
Browse files Browse the repository at this point in the history
  • Loading branch information
amynickolls committed Jun 28, 2024
1 parent b0f3a12 commit 6423528
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion ssda903/multinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import pandas as pd
from demand_model.base import BaseModelPredictor, combine_rates
from demand_model.base import BaseModelPredictor
from demand_model.multinomial.utils import (
build_transition_rates_matrix,
fill_missing_states,
Expand All @@ -31,6 +31,23 @@ class NextPrediction:
variance: np.ndarray


def combine_rates(rate1: pd.Series, rate2: pd.Series) -> pd.Series:
"""'
This has been updated to a multiplication method.
Fill value=1 allows any values missing from the adjustment series to remain unchanged in the transition rates
Any rates not present in rate1 will not be included in output
"""
rate1, rate2 = rate1.align(rate2, fill_value=1)

# Create a mask to identify where rate1 is missing but rate2 is not
mask = (rate1 != 1) | (rate2 == 1)

# Apply the mask to exclude undesired cases
rates = (rate1 * rate2)[mask]
rates.index.names = ["from", "to"]
return rates


class MultinomialPredictor(BaseModelPredictor):
def __init__(
self,
Expand Down

0 comments on commit 6423528

Please sign in to comment.