A three-period consumption-savings model with two regimes:
Working life (ages 25 and 45): The agent chooses whether to work and how much to consume. A simple tax-and-transfer system guarantees a consumption floor. Savings earn interest.
Retirement (age 65): Terminal regime. The agent consumes out of remaining wealth.
Model¶
An agent lives for three periods (ages 25, 45, and 65). In the first two periods (working life), the agent chooses whether to work ($d_t \in \{0, 1\}$) and how much to consume ($c_t$). In the final period (retirement), the agent consumes out of remaining wealth.
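Working life (ages 25 and 45):

$$
V_t(m_t) = \max_{d_t \in \{0, 1\},\; c_t} \left\{ \frac{c_t^{1-\gamma}}{1-\gamma} - \phi\, d_t + \beta\, V_{t+1}(m_{t+1}) \right\}
$$

subject to

$$
\begin{aligned}
e_t &= d_t\, w, \\
T_t &= \begin{cases}
\tau\,(e_t - \underline{c}) & \text{if } e_t \ge \underline{c}, \\
\min\{0,\; m_t + e_t - \underline{c}\} & \text{otherwise},
\end{cases} \\
a_t &= m_t + e_t - T_t - c_t \;\ge\; 0, \\
m_{t+1} &= (1 + r)\, a_t,
\end{aligned}
$$

where $m_t$ is wealth, $e_t$ earnings, $w$ the wage, $\underline{c}$ a consumption floor guaranteed by transfers, $\tau$ the tax rate, $T_t$ taxes net of transfers (negative values are transfers received), and $a_t$ end-of-period wealth. Further, $\gamma$ is the coefficient of relative risk aversion, $\phi$ the disutility of work, $\beta$ the discount factor, and $r$ the interest rate. The transfer only kicks in when the agent's resources ($m_t + e_t$) fall below the consumption floor.

Retirement (age 65, terminal):

$$
V_T(m_T) = \frac{m_T^{1-\gamma}}{1-\gamma}
$$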
from pprint import pprint
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
from lcm import (
AgeGrid,
DiscreteGrid,
LinSpacedGrid,
LogSpacedGrid,
Model,
Regime,
categorical,
)
from lcm.typing import (
BoolND,
ContinuousAction,
ContinuousState,
DiscreteAction,
FloatND,
ScalarInt,
)Categorical Variables¶
@categorical
class Work:
    # Labels of the discrete work choice. They are used as the "work" action grid
    # (DiscreteGrid(Work)) and referenced as Work.yes in utility and earnings.
    # NOTE(review): the integer code of each label presumably follows field order
    # (no, yes) — confirm against lcm's `categorical` before reordering fields.
    no: int
    yes: int
@categorical
class RegimeId:
    # Regime labels. The field names match the keys of the `regimes` dict given to
    # Model, this class is passed as `regime_id_class`, and next_regime returns
    # RegimeId.retirement / RegimeId.working_life.
    # NOTE(review): presumably lcm maps these ids to regimes by name — confirm.
    working_life: int
    retirement: int


# Model Functions
# Utility
def utility(
    consumption: ContinuousAction,
    work: DiscreteAction,
    disutility_of_work: float,
    risk_aversion: float,
) -> FloatND:
    """Per-period utility during working life.

    CRRA utility of consumption, ``c**(1 - risk_aversion) / (1 - risk_aversion)``,
    minus ``disutility_of_work`` whenever the agent works (``work == Work.yes``).
    The expression divides by ``1 - risk_aversion``, so it is undefined for
    ``risk_aversion == 1``.
    """
    curvature = 1 - risk_aversion
    consumption_utility = consumption**curvature / curvature
    work_penalty = disutility_of_work * (work == Work.yes)
    return consumption_utility - work_penalty
def utility_retirement(wealth: ContinuousState, risk_aversion: float) -> FloatND:
    """Terminal utility: CRRA utility of the wealth left at retirement.

    Computes ``wealth**(1 - risk_aversion) / (1 - risk_aversion)``. This is the
    working-life consumption utility evaluated at ``wealth``, without a work term.
    """
    curvature = 1 - risk_aversion
    return wealth**curvature / curvature
# Auxiliary functions
def earnings(work: DiscreteAction, wage: float) -> FloatND:
    """Labor earnings: ``wage`` if the agent works, ``0.0`` otherwise."""
    is_working = work == Work.yes
    return jnp.where(is_working, wage, 0.0)
def taxes_transfers(
    earnings: FloatND,
    wealth: ContinuousState,
    consumption_floor: float,
    tax_rate: float,
) -> FloatND:
    """Net taxes paid; negative values are transfers received.

    * Earnings at or above ``consumption_floor``: pay ``tax_rate`` on the part
      of earnings exceeding the floor.
    * Otherwise: receive the shortfall of resources (``wealth + earnings``)
      below the floor as a transfer, returned as a non-positive number. Agents
      whose resources already reach the floor neither pay nor receive anything.
    """
    resources = wealth + earnings
    tax_on_excess_earnings = tax_rate * (earnings - consumption_floor)
    transfer = jnp.minimum(0.0, resources - consumption_floor)
    return jnp.where(earnings >= consumption_floor, tax_on_excess_earnings, transfer)
def end_of_period_wealth(
    wealth: ContinuousState,
    earnings: FloatND,
    taxes_transfers: FloatND,
    consumption: ContinuousAction,
) -> FloatND:
    """Wealth carried out of the period, before interest is applied.

    Beginning-of-period wealth plus earnings, minus net taxes (so transfers,
    being negative, add to wealth), minus consumption.
    """
    cash_on_hand = wealth + earnings - taxes_transfers
    return cash_on_hand - consumption
# State transition
def next_wealth(end_of_period_wealth: FloatND, interest_rate: float) -> ContinuousState:
    """Next period's wealth: end-of-period wealth grown by the gross interest rate."""
    gross_return = 1 + interest_rate
    return gross_return * end_of_period_wealth
# Constraints
def borrowing_constraint_working(end_of_period_wealth: FloatND) -> BoolND:
    """Feasibility mask: True where end-of-period wealth is non-negative (no borrowing)."""
    return 0 <= end_of_period_wealth
# Regime transition
def next_regime(age: float, last_working_age: float) -> ScalarInt:
    """Regime the agent is in next period.

    Agents at or beyond ``last_working_age`` move to retirement; younger agents
    stay in working life.
    """
    retires_next_period = age >= last_working_age
    return jnp.where(retires_next_period, RegimeId.retirement, RegimeId.working_life)


# Regimes and Model
# Age grid from 25 to 65 in 20-year steps (the three periods 25, 45, 65).
age_grid = AgeGrid(start=25, stop=65, step="20Y")
# The last grid age is the (terminal) retirement age.
retirement_age = age_grid.precise_values[-1]
working_life = Regime(
    # next_regime decides, based on age, whether the agent stays in working life
    # or moves to retirement next period.
    transition=next_regime,
    # Working life is active at every age before the retirement age.
    active=lambda age: age < retirement_age,
    states={
        # Wealth on a linear grid over [0, 50]; it evolves via next_wealth.
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25, transition=next_wealth),
    },
    actions={
        "work": DiscreteGrid(Work),
        # Log-spaced consumption grid over [4, 50].
        # NOTE(review): the lowest consumption point (4) is above the consumption
        # floor (2.0 in params below), so the transfer alone never finances a
        # feasible choice: a non-working agent needs wealth >= 4 to satisfy the
        # borrowing constraint. Confirm this is intended.
        "consumption": LogSpacedGrid(start=4, stop=50, n_points=100),
    },
    functions={
        "utility": utility,
        "earnings": earnings,
        "taxes_transfers": taxes_transfers,
        "end_of_period_wealth": end_of_period_wealth,
    },
    constraints={
        # Rules out negative end-of-period wealth (no borrowing).
        "borrowing_constraint_working": borrowing_constraint_working,
    },
)
retirement = Regime(
    # Terminal regime: there is no transition to a later regime.
    transition=None,
    active=lambda age: age >= retirement_age,
    states={
        # Same wealth grid as in working life, with no further transition.
        # NOTE(review): with risk_aversion > 1 (1.5 in params below),
        # utility_retirement is -inf at wealth = 0, the first grid point —
        # confirm the solver/interpolation handles this.
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25, transition=None),
    },
    # No actions are declared; utility is evaluated directly on wealth.
    functions={"utility": utility_retirement},
)
model = Model(
    regimes={
        "working_life": working_life,
        "retirement": retirement,
    },
    ages=age_grid,
    # Regime labels returned by next_regime; presumably mapped to integer ids.
    regime_id_class=RegimeId,
    description="A tiny three-period consumption-savings model.",
)


# Parameters
Use `model.params_template` to see which parameters the model expects, organized
by regime and function.
# List the parameters the model expects, grouped by regime and then by function.
pprint(dict(model.params_template))
# Output:
# {'retirement': mappingproxy({'next_wealth': mappingproxy({}),
#                              'utility': mappingproxy({'risk_aversion': <class 'float'>})}),
#  'working_life': mappingproxy({'H': mappingproxy({'discount_factor': <class 'float'>}),
#                                'borrowing_constraint_working': mappingproxy({}),
#                                'earnings': mappingproxy({'wage': <class 'float'>}),
#                                'end_of_period_wealth': mappingproxy({}),
#                                'next_regime': mappingproxy({'last_working_age': <class 'float'>}),
#                                'next_wealth': mappingproxy({'interest_rate': <class 'float'>}),
#                                'taxes_transfers': mappingproxy({'consumption_floor': <class 'float'>,
#                                                                 'tax_rate': <class 'float'>}),
#                                'utility': mappingproxy({'disutility_of_work': <class 'float'>,
#                                                         'risk_aversion': <class 'float'>})})}
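Parameters needed in several places, such as `risk_aversion` (used by the utility
functions of both regimes), can be specified once at the model level. Here
`discount_factor` and `interest_rate` are also given at the top level. Parameters
unique to one regime go under the regime name, keyed by the function that uses them.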
# Shared parameters at the top level. Regime-specific parameters are nested under
# the regime and then under the function that uses them (see model.params_template).
params = dict(
    discount_factor=0.95,
    risk_aversion=1.5,
    interest_rate=0.03,
    working_life=dict(
        utility=dict(disutility_of_work=1.0),
        earnings=dict(wage=20.0),
        taxes_transfers=dict(consumption_floor=2.0, tax_rate=0.2),
        # The last working-life age is the second-to-last grid age; from there
        # next_regime sends the agent to retirement.
        next_regime=dict(last_working_age=age_grid.precise_values[-2]),
    ),
)


# Solve and Simulate
n_agents = 100

# Every agent starts in working life at the first grid age. The initial-condition
# validation requires each subject's starting age in `initial_states` alongside
# the model's states. Omitting it raises InvalidInitialConditionsError
# ("'age' must be provided in initial_states").
initial_age = float(age_grid.precise_values[0])
result = model.solve_and_simulate(
    params=params,
    initial_regimes=["working_life"] * n_agents,
    initial_states={
        "age": jnp.full(n_agents, initial_age),
        "wealth": jnp.linspace(1, 40, n_agents),
    },
)
InvalidInitialConditionsError Traceback (most recent call last)
Cell In[7], line 3
1 n_agents = 100
----> 3 result = model.solve_and_simulate(
4 params=params,
5 initial_regimes=["working_life"] * n_agents,
6 initial_states={"wealth": jnp.linspace(1, 40, n_agents)},
7 )
File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/258/src/lcm/model.py:266, in Model.solve_and_simulate(self, params, initial_states, initial_regimes, check_initial_conditions, seed, debug_mode)
262 internal_params = process_params(
263 params=params, params_template=self.params_template
264 )
265 if check_initial_conditions:
--> 266 validate_initial_conditions(
267 initial_states=initial_states,
268 initial_regimes=initial_regimes,
269 internal_regimes=self.internal_regimes,
270 internal_params=internal_params,
271 ages=self.ages,
272 )
273 V_arr_dict = solve(
274 internal_params=internal_params,
275 ages=self.ages,
276 internal_regimes=self.internal_regimes,
277 logger=get_logger(debug_mode=debug_mode),
278 )
279 return simulate(
280 internal_params=internal_params,
281 initial_states=initial_states,
(...) 288 seed=seed,
289 )
File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/258/src/lcm/simulation/validation.py:61, in validate_initial_conditions(initial_states, initial_regimes, internal_regimes, internal_params, ages)
54 structural_errors = _collect_structural_errors(
55 initial_states=initial_states,
56 initial_regimes=initial_regimes,
57 internal_regimes=internal_regimes,
58 ages=ages,
59 )
60 if structural_errors:
---> 61 raise InvalidInitialConditionsError(format_messages(structural_errors))
63 # Validate discrete state values
64 _validate_discrete_state_values(
65 initial_states=initial_states, internal_regimes=internal_regimes
66 )
# InvalidInitialConditionsError: 'age' must be provided in initial_states so the validation knows each subject's starting age. Example: initial_states={'age': jnp.array([25.0, 25.0]), ...} Required initial states are: ['age', 'wealth']
df = result.to_dataframe(additional_targets="all")
df["age"] = df["age"].astype(int)

# Retirement declares no consumption action: the agent consumes its remaining
# wealth, so report wealth as consumption at the retirement age.
is_retirement_age = df["age"] == retirement_age
df.loc[is_retirement_age, "consumption"] = df.loc[is_retirement_age, "wealth"]

# First 20 rows of the main simulated quantities, indexed by subject and age.
columns = [
    "regime",
    "work",
    "consumption",
    "wealth",
    "earnings",
    "taxes_transfers",
    "end_of_period_wealth",
    "value",
]
(
    df.set_index(["subject_id", "age"])
    .loc[:, columns]
    .head(20)
    .style.format(precision=1, na_rep="")
)
# Classify agents by work pattern across the two working-life periods
first_working_age = age_grid.precise_values[0]
last_working_age = age_grid.precise_values[-2]
df_working = df[df["regime"] == "working_life"]
# One row per agent, one column per working-life age, holding the work choice.
work_by_age = df_working.pivot_table(
    index="subject_id",
    columns="age",
    values="work",
    aggfunc="first",
)
# Work pattern per agent, e.g. "yes, no" = worked at the first age only.
work_pattern = (
    work_by_age[first_working_age].astype(str)
    + ", "
    + work_by_age[last_working_age].astype(str)
)
# The work pattern serves as a label for the agent's initial-wealth group.
label_map = {
    "yes, no": "low",  # work early, not later
    "no, yes": "medium",  # coast early, work later
    "no, no": "high",  # never work
}
# Fail loudly on any pattern without a label (e.g. "yes, yes"). Such agents would
# otherwise get a NaN group and silently drop out of the groupby tables and plots.
# An explicit check, unlike `assert`, is not stripped under `python -O`.
unexpected_patterns = sorted(set(work_pattern) - set(label_map))
if unexpected_patterns:
    msg = (
        "Plotting assumes every agent's work pattern is one of "
        f"{list(label_map)}; found unexpected patterns: {unexpected_patterns}."
    )
    raise ValueError(msg)
groups = work_pattern.map(label_map).rename("initial_wealth")

# Combined descriptives and work decisions table
# Range of initial wealth within each group.
initial_wealth = df[df["age"] == first_working_age].set_index("subject_id")["wealth"]
group_desc = initial_wealth.groupby(groups).agg(["min", "max"]).round(1)
# Group means of every numeric column at each age.
df_groups = df.copy()
df_groups["initial_wealth"] = df_groups["subject_id"].map(groups)
df_mean = df_groups.groupby(["initial_wealth", "age"], as_index=False).mean(
    numeric_only=True,
)
# 1 if the group works at a working-life age (positive mean earnings), else 0.
work_table = df_mean[df_mean["age"] < retirement_age].pivot_table(
    index="initial_wealth",
    columns="age",
    values="earnings",
)
work_table = (work_table > 0).astype(int)
work_table.columns = [f"works {c}" for c in work_table.columns]
summary = pd.concat([group_desc, work_table], axis=1)
summary.index.name = "initial_wealth"
# Show the groups in a fixed low/medium/high order. Groups that no agent falls
# into are skipped, because `.loc` with a missing label would raise a KeyError.
group_order = ["low", "medium", "high"]
present_groups = [group for group in group_order if group in summary.index]
summary.loc[present_groups].style.format(precision=1, na_rep="")
fig = px.line(
    df_mean,
    x="age",
    y="consumption",
    color="initial_wealth",
    # Legend in the same order as the summary table; plotly's default order
    # follows first appearance in the data.
    category_orders={"initial_wealth": group_order},
    title="Consumption by Age",
    template="plotly_dark",
)
fig.show()
fig = px.line(
    df_mean,
    x="age",
    y="wealth",
    color="initial_wealth",
    category_orders={"initial_wealth": group_order},
    title="Wealth by Age",
    template="plotly_dark",
)
fig.show()