class aftercovid.models.EpidemicRegressor(model='SIR', t=0, max_iter=100, learning_rate_init=0.1, lr_schedule='constant', momentum=0.9, power_t=0.5, early_th=None, min_threshold='auto', max_threshold='auto', verbose=False, init=None)

Follows the scikit-learn API. Trains a model on observed data from an epidemic.

  • model – model to train: ‘SIR’ or ‘SIRD’ refers to CovidSIRD (aftercovid.models.CovidSIRD), ‘SIRDc’ refers to CovidSIRDc (aftercovid.models.CovidSIRDc)

  • t – implicit feature

  • max_iter – maximum number of iterations

  • learning_rate_init – see SGDOptimizer

  • lr_schedule – see SGDOptimizer

  • momentum – see SGDOptimizer

  • power_t – see SGDOptimizer

  • early_th – see SGDOptimizer

  • verbose – see SGDOptimizer

  • min_threshold – see SGDOptimizer; if ‘auto’, the value depends on the model: for model SIR it is 0.01, which means every coefficient must be greater than 0.01.

  • max_threshold – see SGDOptimizer; upper bound on every coefficient, the counterpart of min_threshold

  • init – dictionary of parameters used to initialize the model

Once trained, the model holds an attribute model_ containing the trained model and an attribute iter_ holding the number of training iterations. It also keeps track of the coefficients in a dictionary stored in attribute coef_.
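The optimizer parameters above can be illustrated with a minimal sketch of projected gradient descent with momentum, where the coefficients are clipped to [min_threshold, max_threshold] after every step. It assumes a constant learning rate (as with lr_schedule='constant') and classical heavy-ball momentum; the function projected_sgd and its grad argument are hypothetical and are not the actual SGDOptimizer implementation.

```python
import numpy as np


def projected_sgd(grad, x0, max_iter=100, learning_rate_init=0.1,
                  momentum=0.9, min_threshold=None, max_threshold=None):
    """Hypothetical sketch: momentum gradient descent with box constraints.

    After every update the coefficients are clipped to
    [min_threshold, max_threshold]; None means no bound on that side.
    """
    x = np.asarray(x0, dtype=float).copy()
    velocity = np.zeros_like(x)
    for _ in range(max_iter):
        # Heavy-ball momentum with a constant learning rate.
        velocity = momentum * velocity - learning_rate_init * grad(x)
        x = x + velocity
        # Projection step: np.clip needs at least one bound.
        if min_threshold is not None or max_threshold is not None:
            x = np.clip(x, min_threshold, max_threshold)
    return x


# Minimize ||x - target||^2 while keeping every coefficient in [0.01, 2].
target = np.array([-1.0, 0.5, 3.0])
coef = projected_sgd(lambda x: 2.0 * (x - target), [1.0, 1.0, 1.0],
                     max_iter=500, min_threshold=0.01, max_threshold=2.0)
print(coef)   # close to [0.01, 0.5, 2.0]
```

Clipping after each step is a simple way to enforce a constraint such as "every coefficient must be greater than 0.01": coefficients whose unconstrained optimum lies outside the box end up on the nearest bound.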

fit(X, y)

Trains a model to approximate its derivatives as closely as possible.
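To make the fitting target concrete, the sketch below assumes a standard SIRD formulation with coefficients beta (infection), mu (recovery) and nu (death), where X holds the quantities S, I, R, D and y their derivatives. These equations and names are assumptions, not taken from CovidSIRD. Because every derivative is linear in the coefficients under this formulation, they can be recovered with ordinary least squares; this swaps in least squares for illustration, whereas EpidemicRegressor itself trains with SGDOptimizer. The hypothetical fit_coefficients returns a dictionary, mirroring the coef_ attribute.

```python
import numpy as np


def sird_derivatives(X, beta, mu, nu):
    """Derivatives of an assumed SIRD model; X has columns S, I, R, D."""
    S, I, R, D = X.T
    N = S + I + R + D                     # total population
    infections = beta * S * I / N
    return np.column_stack([-infections,
                            infections - mu * I - nu * I,
                            mu * I,
                            nu * I])


def fit_coefficients(X, y):
    """Hypothetical helper: recover beta, mu, nu from X and y by least squares."""
    S, I, R, D = X.T
    N = S + I + R + D
    zeros = np.zeros_like(S)
    # One block of rows per compartment; columns correspond to beta, mu, nu.
    A = np.vstack([
        np.column_stack([-S * I / N, zeros, zeros]),  # dS
        np.column_stack([S * I / N, -I, -I]),         # dI
        np.column_stack([zeros, I, zeros]),           # dR
        np.column_stack([zeros, zeros, I]),           # dD
    ])
    # Stack the targets in the same compartment order as A.
    b = np.concatenate([y[:, 0], y[:, 1], y[:, 2], y[:, 3]])
    coef, *_ = np.linalg.lstsq(A, b, rcond=None)
    return dict(zip(["beta", "mu", "nu"], coef))


rng = np.random.default_rng(0)
X = rng.uniform(100.0, 1000.0, size=(20, 4))   # synthetic S, I, R, D values
y = sird_derivatives(X, beta=0.5, mu=0.1, nu=0.01)
coef = fit_coefficients(X, y)
print(coef)   # values close to beta=0.5, mu=0.1, nu=0.01
```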


Predicts the derivatives.

predict_many(X, n=7)

Predicts the derivatives and the series for many days.

  • X – series

  • n – number of days


derivatives and series; the return shape is (X.shape[0], number of parameters, n)
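The following sketch shows one way derivatives and series over n days can be produced: predict the derivatives, take one explicit Euler step of one day, and repeat. It assumes an SIR model with fixed coefficients (beta=0.5, mu=0.1), treats the middle axis as one entry per tracked quantity (S, I, R), and stores the state at the end of each day; all of these are assumptions, and predict_many_sketch is a hypothetical helper rather than the library method.

```python
import numpy as np


def sir_derivatives(X, beta=0.5, mu=0.1):
    """Assumed SIR derivatives; X has one row per state and columns S, I, R."""
    S, I, R = X.T
    N = S + I + R
    infections = beta * S * I / N
    return np.column_stack([-infections, infections - mu * I, mu * I])


def predict_many_sketch(X, n=7, derivatives=sir_derivatives):
    """Roll every row of X forward n days, one explicit Euler step per day.

    Returns (dxdt, series), each shaped (X.shape[0], number of quantities, n).
    """
    current = np.asarray(X, dtype=float).copy()
    rows, n_quantities = current.shape
    dxdt = np.empty((rows, n_quantities, n))
    series = np.empty((rows, n_quantities, n))
    for day in range(n):
        d = derivatives(current)
        dxdt[:, :, day] = d
        current = current + d          # time step of one day
        series[:, :, day] = current    # state at the end of this day
    return dxdt, series


X = np.array([[990.0, 10.0, 0.0],
              [900.0, 100.0, 0.0]])
dxdt, series = predict_many_sketch(X, n=5)
print(dxdt.shape, series.shape)        # (2, 3, 5) (2, 3, 5)
```

Because the three SIR derivatives sum to zero, the total population along each row of series stays constant, which is a cheap sanity check on this kind of roll-out.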

score(X, y, norm=None)

Scores the prediction of the derivatives.

  • X – data

  • y – expected derivatives

  • norm – None to return the norm used during optimization (L2), or ‘L1’ to return the L1 norm
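A minimal sketch of the difference between the two norms, assuming the score averages the squared residuals (L2) or the absolute residuals (L1) over all entries. The exact normalization used by score is not stated here, and derivative_error is a hypothetical helper that returns an error, so lower is better.

```python
import numpy as np


def derivative_error(y_true, y_pred, norm=None):
    """Hypothetical helper: mean squared residual (L2) when norm is None,
    mean absolute residual when norm == 'L1'."""
    residuals = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    if norm is None:
        return float(np.mean(residuals ** 2))
    if norm == "L1":
        return float(np.mean(np.abs(residuals)))
    raise ValueError(f"unsupported norm: {norm!r}")


y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
y_pred = np.array([[1.0, 0.0], [3.0, 5.0]])
print(derivative_error(y_true, y_pred))          # 1.25
print(derivative_error(y_true, y_pred, "L1"))    # 0.75
```

With the same residuals, the L2 value is dominated by the single error of 2, while L1 grows only linearly with it, which makes L1 less sensitive to a few badly predicted days.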



simulate(X, n=7)

Predicts and simulates the epidemic. Every row of X is a starting point; the function then simulates the epidemic for the next n days from each starting point.

  • X – data

  • n – number of days


quantities, a matrix of shape (X.shape[0], n, number of parameters)
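The sketch below illustrates the documented layout of simulate: every row of X is an independent starting point, and the output stacks n days of quantities per row, giving shape (X.shape[0], n, number of parameters). It assumes an SIR model with fixed coefficients and one explicit Euler step per day; simulate_sketch is hypothetical, not the library's implementation. np.swapaxes converts the result to the (X.shape[0], number of parameters, n) order documented for predict_many.

```python
import numpy as np


def simulate_sketch(X, n=7, beta=0.5, mu=0.1):
    """Hypothetical SIR simulation with one explicit Euler step per day.

    Each row of X (columns S, I, R) is an independent starting point; the
    result has shape (X.shape[0], n, 3).
    """
    current = np.asarray(X, dtype=float).copy()
    out = np.empty((current.shape[0], n, current.shape[1]))
    for day in range(n):
        S, I, R = current.T
        N = S + I + R
        infections = beta * S * I / N
        step = np.column_stack([-infections, infections - mu * I, mu * I])
        current = current + step
        out[:, day, :] = current
    return out


starts = np.array([[990.0, 10.0, 0.0],      # early outbreak
                   [500.0, 500.0, 0.0],     # half the population infected
                   [100.0, 0.0, 900.0]])    # no infected people: nothing changes
sim = simulate_sketch(starts, n=10)
print(sim.shape)                         # (3, 10, 3)
print(np.swapaxes(sim, 1, 2).shape)      # (3, 3, 10), predict_many's axis order
```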