Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

A Tiny Example

A three-period consumption-savings model with two regimes:

  • Working life (ages 25 and 45): The agent chooses whether to work and how much to consume. A simple tax-and-transfer system guarantees a consumption floor. Savings earn interest.

  • Retirement (age 65): Terminal regime. The agent consumes out of remaining wealth.

Model

An agent lives for three periods (ages 25, 45, and 65). In the first two periods (working life), the agent chooses whether to work, $d_t \in \{0, 1\}$, and how much to consume, $c_t$. In the final period (retirement), the agent consumes out of remaining wealth.

Working life (ages 25 and 45):

$$V_t(w_t) = \max_{d_t,\, c_t} \left\{ \frac{c_t^{1-\sigma}}{1-\sigma} - \phi \, d_t + \beta \, V_{t+1}(w_{t+1}) \right\}$$

subject to

$$
\begin{align}
e_t &= d_t \cdot \bar{w} \\[4pt]
\tau(e_t, w_t) &= \begin{cases} \theta\,(e_t - \underline{c}) & \text{if } e_t \geq \underline{c} \\ \min(0,\; w_t + e_t - \underline{c}) & \text{otherwise} \end{cases} \\[4pt]
a_t &= w_t + e_t - \tau(e_t, w_t) - c_t \\[4pt]
w_{t+1} &= (1 + r)\, a_t \\[4pt]
a_t &\geq 0
\end{align}
$$

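where $w_t$ is wealth, $e_t$ earnings, $\bar{w}$ the wage, $\underline{c}$ a consumption floor guaranteed by transfers, $\theta$ the tax rate, $a_t$ end-of-period wealth, $r$ the interest rate, $\sigma$ the coefficient of relative risk aversion, $\phi$ the disutility of work, and $\beta$ the discount factor. The transfer (a negative $\tau$) only kicks in when the agent’s resources ($w_t + e_t$) fall below the consumption floor. In that case it tops resources up to exactly $\underline{c}$.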

Retirement (age 65, terminal):

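$$V_2(w_2) = \max_{c_2 \leq w_2} \frac{c_2^{1-\sigma}}{1-\sigma}$$

Because utility is increasing in consumption, the agent consumes all remaining wealth ($c_2 = w_2$), so $V_2(w_2) = \frac{w_2^{1-\sigma}}{1-\sigma}$.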
from pprint import pprint

import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from lcm import (
    AgeGrid,
    DiscreteGrid,
    LinSpacedGrid,
    LogSpacedGrid,
    Model,
    Regime,
    categorical,
)
from lcm.typing import (
    BoolND,
    ContinuousAction,
    ContinuousState,
    DiscreteAction,
    FloatND,
    ScalarInt,
)

Categorical Variables

@categorical
class Work:
    # Binary labor-supply choice, used as the "work" action of working life.
    # The integer code behind each label is assigned by @categorical
    # (presumably in declaration order -- confirm in lcm); model functions
    # compare the action against Work.yes.
    no: int
    yes: int


@categorical
class RegimeId:
    # Regime labels. The field names match the keys passed to
    # Model(regimes=...), and these codes are what next_regime returns.
    working_life: int
    retirement: int

Model Functions

# Utility


def utility(
    consumption: ContinuousAction,
    work: DiscreteAction,
    disutility_of_work: float,
    risk_aversion: float,
) -> FloatND:
    """Flow utility during working life.

    CRRA utility of consumption, ``c**(1 - sigma) / (1 - sigma)``, minus the
    fixed cost ``disutility_of_work`` whenever ``work`` equals ``Work.yes``.
    The CRRA term divides by ``1 - risk_aversion`` and is therefore undefined
    for ``risk_aversion == 1``.
    """
    one_minus_sigma = 1 - risk_aversion
    consumption_utility = consumption**one_minus_sigma / one_minus_sigma
    # Boolean indicator: multiplying by it charges the cost only when working.
    is_working = work == Work.yes
    return consumption_utility - disutility_of_work * is_working


def utility_retirement(wealth: ContinuousState, risk_aversion: float) -> FloatND:
    """Terminal CRRA utility in retirement, with all remaining wealth consumed.

    Consumption equals ``wealth``, so utility is
    ``wealth**(1 - sigma) / (1 - sigma)``.
    """
    one_minus_sigma = 1 - risk_aversion
    return wealth**one_minus_sigma / one_minus_sigma


# Auxiliary functions


def earnings(work: DiscreteAction, wage: float) -> FloatND:
    """Labor earnings: ``wage`` when the agent works, ``0.0`` otherwise."""
    works = work == Work.yes
    return jnp.where(works, wage, 0.0)


def taxes_transfers(
    earnings: FloatND,
    wealth: ContinuousState,
    consumption_floor: float,
    tax_rate: float,
) -> FloatND:
    """Net taxes paid (positive) or transfers received (negative).

    Earnings at or above ``consumption_floor`` are taxed at ``tax_rate`` on
    the part exceeding the floor. Below it no tax is levied. Instead, a
    transfer tops resources (``wealth + earnings``) up to the floor whenever
    they fall short of it.
    """
    earns_above_floor = earnings >= consumption_floor
    income_tax = tax_rate * (earnings - consumption_floor)
    # Negative (i.e. a transfer) only when resources are below the floor;
    # capped at zero so agents with enough resources receive nothing.
    shortfall = jnp.minimum(0.0, wealth + earnings - consumption_floor)
    return jnp.where(earns_above_floor, income_tax, shortfall)


def end_of_period_wealth(
    wealth: ContinuousState,
    earnings: FloatND,
    taxes_transfers: FloatND,
    consumption: ContinuousAction,
) -> FloatND:
    """Savings carried forward: post-tax resources minus consumption."""
    pre_tax_resources = wealth + earnings
    post_tax_resources = pre_tax_resources - taxes_transfers
    return post_tax_resources - consumption


# State transition


def next_wealth(end_of_period_wealth: FloatND, interest_rate: float) -> ContinuousState:
    """Wealth at the start of next period: savings times the gross return."""
    gross_return = 1 + interest_rate
    return gross_return * end_of_period_wealth


# Constraints


def borrowing_constraint_working(end_of_period_wealth: FloatND) -> BoolND:
    """Feasibility mask: end-of-period wealth must be non-negative (no borrowing)."""
    return 0 <= end_of_period_wealth


# Regime transition


def next_regime(age: float, last_working_age: float) -> ScalarInt:
    """Regime code for the next period.

    Agents whose current age is at or beyond ``last_working_age`` move to
    retirement; everyone else stays in working life.
    """
    retires_next_period = age >= last_working_age
    return jnp.where(
        retires_next_period, RegimeId.retirement, RegimeId.working_life
    )

Regimes and Model

age_grid = AgeGrid(start=25, stop=65, step="20Y")
# Last grid age; the retirement regime is active only from this age on.
retirement_age = age_grid.precise_values[-1]

# Both regimes use the same wealth-grid layout; only the transition differs.
wealth_grid_spec = {"start": 0, "stop": 50, "n_points": 25}

working_life = Regime(
    transition=next_regime,
    active=lambda age: age < retirement_age,
    states={
        "wealth": LinSpacedGrid(**wealth_grid_spec, transition=next_wealth),
    },
    actions={
        "work": DiscreteGrid(Work),
        # NOTE(review): this grid starts at 4, above the consumption_floor of
        # 2.0 set in `params`, so a non-working agent with wealth below 4 has
        # no feasible consumption level and must work. Confirm this is intended.
        "consumption": LogSpacedGrid(start=4, stop=50, n_points=100),
    },
    functions={
        "utility": utility,
        "earnings": earnings,
        "taxes_transfers": taxes_transfers,
        "end_of_period_wealth": end_of_period_wealth,
    },
    constraints={
        "borrowing_constraint_working": borrowing_constraint_working,
    },
)

# Terminal regime: no actions and no wealth transition.
retirement = Regime(
    transition=None,
    active=lambda age: age >= retirement_age,
    states={
        "wealth": LinSpacedGrid(**wealth_grid_spec, transition=None),
    },
    functions={"utility": utility_retirement},
)

model = Model(
    regimes={
        "working_life": working_life,
        "retirement": retirement,
    },
    ages=age_grid,
    regime_id_class=RegimeId,
    description="A tiny three-period consumption-savings model.",
)

Parameters

Use model.params_template to see what parameters the model expects, organized by regime and function.

# Nested view: regime -> function -> parameter name -> expected type.
template_overview = dict(model.params_template)
pprint(template_overview)
{'retirement': mappingproxy({'next_wealth': mappingproxy({}),
                             'utility': mappingproxy({'risk_aversion': <class 'float'>})}),
 'working_life': mappingproxy({'H': mappingproxy({'discount_factor': <class 'float'>}),
                               'borrowing_constraint_working': mappingproxy({}),
                               'earnings': mappingproxy({'wage': <class 'float'>}),
                               'end_of_period_wealth': mappingproxy({}),
                               'next_regime': mappingproxy({'last_working_age': <class 'float'>}),
                               'next_wealth': mappingproxy({'interest_rate': <class 'float'>}),
                               'taxes_transfers': mappingproxy({'consumption_floor': <class 'float'>,
                                                                'tax_rate': <class 'float'>}),
                               'utility': mappingproxy({'disutility_of_work': <class 'float'>,
                                                        'risk_aversion': <class 'float'>})})}

Parameters shared across regimes (risk_aversion, discount_factor, interest_rate) can be specified once at the model level. Parameters unique to one regime go under the regime name, nested under the function that uses them.

# Working-life-only parameters, keyed by the function that consumes them.
working_life_params = {
    "utility": {"disutility_of_work": 1.0},
    "earnings": {"wage": 20.0},
    "taxes_transfers": {"consumption_floor": 2.0, "tax_rate": 0.2},
    # next_regime switches to retirement once age reaches this value.
    "next_regime": {"last_working_age": age_grid.precise_values[-2]},
}

# Parameters shared across regimes are given once at the model level.
params = {
    "discount_factor": 0.95,
    "risk_aversion": 1.5,
    "interest_rate": 0.03,
    "working_life": working_life_params,
}

Solve and Simulate

n_agents = 100
# Every agent starts at the first grid age in working life. The simulator
# requires the starting age as an explicit initial state; omitting it raises
# InvalidInitialConditionsError ("'age' must be provided in initial_states").
start_age = float(age_grid.precise_values[0])

result = model.solve_and_simulate(
    params=params,
    initial_regimes=["working_life"] * n_agents,
    initial_states={
        "age": jnp.full(n_agents, start_age),
        "wealth": jnp.linspace(1, 40, n_agents),
    },
)
---------------------------------------------------------------------------
InvalidInitialConditionsError             Traceback (most recent call last)
Cell In[7], line 3
      1 n_agents = 100
----> 3 result = model.solve_and_simulate(
      4     params=params,
      5     initial_regimes=["working_life"] * n_agents,
      6     initial_states={"wealth": jnp.linspace(1, 40, n_agents)},
      7 )

File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/258/src/lcm/model.py:266, in Model.solve_and_simulate(self, params, initial_states, initial_regimes, check_initial_conditions, seed, debug_mode)
    262 internal_params = process_params(
    263     params=params, params_template=self.params_template
    264 )
    265 if check_initial_conditions:
--> 266     validate_initial_conditions(
    267         initial_states=initial_states,
    268         initial_regimes=initial_regimes,
    269         internal_regimes=self.internal_regimes,
    270         internal_params=internal_params,
    271         ages=self.ages,
    272     )
    273 V_arr_dict = solve(
    274     internal_params=internal_params,
    275     ages=self.ages,
    276     internal_regimes=self.internal_regimes,
    277     logger=get_logger(debug_mode=debug_mode),
    278 )
    279 return simulate(
    280     internal_params=internal_params,
    281     initial_states=initial_states,
   (...)    288     seed=seed,
    289 )

File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/258/src/lcm/simulation/validation.py:61, in validate_initial_conditions(initial_states, initial_regimes, internal_regimes, internal_params, ages)
     54 structural_errors = _collect_structural_errors(
     55     initial_states=initial_states,
     56     initial_regimes=initial_regimes,
     57     internal_regimes=internal_regimes,
     58     ages=ages,
     59 )
     60 if structural_errors:
---> 61     raise InvalidInitialConditionsError(format_messages(structural_errors))
     63 # Validate discrete state values
     64 _validate_discrete_state_values(
     65     initial_states=initial_states, internal_regimes=internal_regimes
     66 )

InvalidInitialConditionsError: 'age' must be provided in initial_states so the validation knows each subject's starting age. Example: initial_states={'age': jnp.array([25.0, 25.0]), ...} Required initial states are: ['age', 'wealth']
df = result.to_dataframe(additional_targets="all")
df["age"] = df["age"].astype(int)

# Retirement has no consumption action: the agent consumes all remaining
# wealth, so fill the consumption column from wealth at the retirement age.
in_retirement = df["age"] == retirement_age
df.loc[in_retirement, "consumption"] = df.loc[in_retirement, "wealth"]

columns = [
    "regime",
    "work",
    "consumption",
    "wealth",
    "earnings",
    "taxes_transfers",
    "end_of_period_wealth",
    "value",
]
# First 20 rows, indexed by agent and age, rendered with one decimal.
(
    df.set_index(["subject_id", "age"])[columns]
    .head(20)
    .style.format(precision=1, na_rep="")
)
# Classify agents by work pattern across the two working-life periods
first_working_age = age_grid.precise_values[0]
last_working_age = age_grid.precise_values[-2]

df_working = df[df["regime"] == "working_life"]
work_by_age = df_working.pivot_table(
    index="subject_id",
    columns="age",
    values="work",
    aggfunc="first",
)
# One string per agent, e.g. "yes, no" = works at the first age, not the second.
work_pattern = (
    work_by_age[first_working_age].astype(str)
    + ", "
    + work_by_age[last_working_age].astype(str)
)
assert "yes, yes" not in work_pattern.to_numpy(), (
    "Plotting assumes that no agent works in both periods of working life."
)

label_map = {
    "yes, no": "low",  # work early, not later
    "no, yes": "medium",  # coast early, work later
    "no, no": "high",  # never work
}
groups = work_pattern.map(label_map).rename("initial_wealth")

# Combined descriptives and work decisions table
initial_wealth = df[df["age"] == first_working_age].set_index("subject_id")["wealth"]
group_desc = initial_wealth.groupby(groups).agg(["min", "max"]).round(1)

df_groups = df.copy()
df_groups["initial_wealth"] = df_groups["subject_id"].map(groups)
df_mean = df_groups.groupby(["initial_wealth", "age"], as_index=False).mean(
    numeric_only=True,
)
work_table = df_mean[df_mean["age"] < retirement_age].pivot_table(
    index="initial_wealth",
    columns="age",
    values="earnings",
)
# Earnings are zero whenever an agent does not work, so positive group-mean
# earnings at an age mark that group as working at that age.
work_table = (work_table > 0).astype(int)
work_table.columns = [f"works {c}" for c in work_table.columns]

summary = pd.concat([group_desc, work_table], axis=1)
summary.index.name = "initial_wealth"
# reindex (unlike .loc) tolerates a group that no agent falls into: it is
# shown as an empty row instead of raising KeyError.
summary.reindex(["low", "medium", "high"]).style.format(precision=1, na_rep="")
# Group-average life-cycle profiles, one line per initial-wealth group.
for outcome, chart_title in (
    ("consumption", "Consumption by Age"),
    ("wealth", "Wealth by Age"),
):
    fig = px.line(
        df_mean,
        x="age",
        y=outcome,
        color="initial_wealth",
        title=chart_title,
        template="plotly_dark",
    )
    fig.show()