Efficient value function iteration in Julia

Greetings:

I have been loosely adapting into Julia some Fortran code that solves the standard cake-eating problem using spline interpolation and numerical maximization. Whereas the original (compiled) Fortran code finishes in seconds, the Julia code takes many minutes.

As such, it seems I am inadvertently introducing something very inefficient. The example is based on the textbook exercises in https://www.ce-fortran.com by Hans Fehr and Fabian Kindermann, specifically Chapter 8 Exercise 5. To be clear, there are differences in the numerical maximization step. However, the parameterization, size of grid, and convergence criterion should be identical.

I attach the Julia code below.

using PyPlot
using Plots
using BenchmarkTools
using LaTeXStrings
using Parameters, QuantEcon
using  Dierckx, ArgParse
using LinearAlgebra, Optim, Interpolations
using Printf
using StatsBase
using Optim: converged, maximum, maximizer, minimizer, iterations # some extra functions
#import Base.Threads.@threads
using ProfileView

#Threads.nthreads()

@with_kw  struct Para
    # model parameters
    γ::Float64 = 0.5
    egam::Float64 = 1-1/γ
    β::Float64 = 0.95
    a0::Float64 = 100

    # numerical parameters
    σ::Float64 = 1e-6
    itermax::Int64 = 2000
    T::Int64 = 200
    NA::Int64 = 1000
    a = range(0., 100., length=NA)
end

function TV(a_prime, ia, V, para)
    # iteration on value function
    @unpack egam, β, a = para
    c = max(a[ia]-a_prime, 1e-10)
    # interpolating function for value function
    V_temp = CubicSplineInterpolation(a, (egam*V).^(1.0/egam))(a_prime)
    V_fun = max(V_temp, 1e-10)^(egam)/egam
    # calculate negative of RHS
    val = -(c^egam/egam + β*V_fun)
    return val
end

function f(ia, V, para)
    # numerical maximization
    @unpack a = para
     result = optimize(x ->TV(x, ia, V, para), 0.0, a[ia], Brent())
     return -result.minimum, result.minimizer
end

function update(a, V, para)
    c = similar(V)
    V_new = similar(V)
    @inbounds for ia in eachindex(a)
        #a_prime = a[ia] - c[ia]
        #result = optimize(x ->TV(x, ia, V, para), 0.0, a[ia], Brent())
        #result = optimize(f, [a_prime], LBFGS())
        # optimal consumption and value functions
        #V_new[ia] = -minimum(result)
        V_new[ia], a_prime = f(ia, V, para)
        c[ia] = a[ia] - a_prime
    end
    return c, V_new
end


function value_function_iter(para)
    @unpack γ, egam, β, a0, σ, itermax, T, NA, a = para
    #initial consumption guess
    #c = collect(a)./2
    c = collect(a./2)
    # initial value function based on consumption guess
    V = c.^egam/egam./(1-β)
    #V = zero(c)
    conv_level = 1.0
    iter = 1
    while (iter < itermax) && (conv_level > σ)
        #f(ia) = optimize(x ->TV(x, ia, V, para), 0.0, a[ia], Brent())
        # calculate optimal decision for every grid point 
        #Threads.@threads for ia in eachindex(a)
        c, V_new = update(a, V, para)

        # get convergence level
        crit = abs.(V_new-V)./max.(abs.(V), 1e-10)
        crit = filter(!isnan, crit)
        conv_level = maximum(crit)
        println("Iteration=$(iter); convergence level=$(conv_level)")
        # update value function and iteration count
        V = V_new
        iter += 1
    end
    if conv_level > σ
        println("Failed to converge")
    end
    return c, V
end

function main()
    para = Para(NA=1000)
    c, V = value_function_iter(para)
end
c, V = main()
# smaller sample for profiling
para = Para(NA=100)
@btime value_function_iter($para)
ProfileView.@profview value_function_iter(para)


#@btime TV($a_prime, $ia, $V, $para)
function simulate_consumption(c; a0=100, T=200)
    # Policy function
    c_pol(x) = Spline1D(a, c)(x)
    a_t = zeros(T)
    c_t = zeros(T)
    a_t[1] = a0
    c_t[1] = c_pol(a0)
    for i in 2:T
        a_t[i] = a_t[i-1] - c_t[i-1]
        c_t[i] = c_pol(a_t[i])
    end
    return a_t, c_t
end

# grid and horizon matching the NA=1000 solution returned by main()
a, T = Para().a, Para().T
a_t, c_t = simulate_consumption(c)

t = range(1, T, length=T)

fig, ax = subplots(1, 1, figsize=(10, 5))
ax.plot(t, c_t, label="c")
ax.set_title("Consumption path")
ax.legend()
display(fig)
PyPlot.savefig("consumption_path")

fig, ax = subplots(nrows=1, ncols=2, figsize=(20, 5))
ax[1].plot(a, c, label="policy function c(a)")
ax[2].plot(a[2:end], V[2:end], label="value function V(a)")
ax[1].legend()
ax[2].legend()
tight_layout()
display(fig)

For reference, I also include the original Fortran code:

!##############################################################################
! PROGRAM Minimize
!
! ## Value function iteration and numerical minimization
!
! This code is published under the GNU General Public License v3
!                         (https://www.gnu.org/licenses/gpl-3.0.en.html)
!
! Authors: Hans Fehr and Fabian Kindermann
!          contact@ce-fortran.com
!
!##############################################################################
include "prog08_05m.f90"

program Minimize

    use globals
    use toolbox

    implicit none
    integer :: ia, iter

    ! start timer
    call tic()

    ! initialize a, c and value function
    call grid_Cons_Equi(a, 0d0, a0)
    c(:) = a(:)/2d0
    V(:) = 0d0
    coeff_v(:) = 0d0

    ! iterate until value function converges
    do iter = 1, itermax

        ! set a = 0 manually
        c(0) = 0d0
        V_new(0) = -1d10

        ! calculate optimal decision for every gridpoint
        do ia = 1, NA

            ! initialize starting value and communicate resources
            x_in = a(ia) - c(ia)
            ia_com = ia

            call fminsearch(x_in, fret, 0d0, a(ia), utility)

            ! get optimal consumption and value function
            c(ia) = a(ia) - x_in
            V_new(ia) = -fret

        enddo

        ! interpolate coefficients
        call spline_interp((egam*V_new)**(1d0/egam), coeff_V)

        ! get convergence level
        con_lev = maxval(abs(V_new(:) - V(:))/max(abs(V(:)), 1d-10))
        write(*,'(i5,2x,f20.7)')iter, con_lev

        ! check for convergence
        if(con_lev < sig)then
            call output()
        endif

        V = V_new
    enddo

    write(*,*)'No Convergence'

contains


    ! For creating output plots.
    subroutine output()

        use toolbox

        implicit none
        integer, parameter :: n_err = 10000
        integer :: it
        real*8 :: err, a_err, err_temp

        ! end timer
        call toc()

        ! interpolate policy function
        call spline_interp(c, coeff_c)

        ! calculate the time path of consumption numerically
        a_t(0) = a0
        c_t(0) = spline_eval(a_t(0), coeff_c, 0d0, a0)
        do it = 1, TT
            a_t(it) = a_t(it-1) - c_t(it-1)
            c_t(it) = spline_eval(a_t(it), coeff_c, 0d0, a0)
        enddo
        call plot((/(dble(it),it=0,TT)/), c_t, legend='numerical')

        ! calculate the time path of consumption analytically
        a_t(0) = a0
        c_t(0) = a_t(0)*(1d0-beta**gamma)
        do it = 1, TT
            a_t(it) = a_t(it-1) - c_t(it-1)
            c_t(it) = a_t(it)*(1d0-beta**gamma)
        enddo
        call plot((/(dble(it),it=0,TT)/), c_t, legend='analytical')
        call execplot(xlabel='Time t', ylabel='Consumption c_t')

        ! plot numerical and analytical consumption
        call plot(a, c, legend='numerical')
        call plot(a, a*(1d0-beta**gamma), legend='analytical')
        call execplot(xlabel='Level of resources a', ylabel='Policy Function c(a)')

        ! plot numerical and analytical value function
        call plot(a(10:NA), V(10:NA), legend='numerical')
        call plot(a(10:NA), (1d0-beta**gamma)**(-1d0/gamma)*a(10:NA) &
                                                **egam/egam, legend='analytical')
        call execplot(xlabel='Level of Resources a', ylabel='Value Function V(a)')

        ! calculate consumption function error
        err = 0d0
        do ia = 0, n_err
            a_err = a0*dble(ia)/dble(n_err)
            err_temp = abs(spline_eval(a_err, coeff_c, 0d0, a0) - a_err*(1d0-beta**gamma)) &
                                            /max(a_err*(1d0-beta**gamma), 1d-10)
            if(err_temp > err)err = err_temp
        enddo
        write(*,'(a, es15.7)')'Consumption function error:',err

        ! quit program
        stop

    end subroutine

end program

This main file calls the following module:

!##############################################################################
! MODULE globals
!
! This code is published under the GNU General Public License v3
!                         (https://www.gnu.org/licenses/gpl-3.0.en.html)
!
! Authors: Hans Fehr and Fabian Kindermann
!          contact@ce-fortran.com
!
!##############################################################################
module globals

    implicit none

    ! model parameters
    real*8, parameter :: gamma = 0.5d0
    real*8, parameter :: egam = 1d0-1d0/gamma
    real*8, parameter :: beta = 0.95d0
    real*8, parameter :: a0 = 100d0

    ! numerical parameters
    real*8, parameter :: sig = 1d-6
    integer, parameter :: itermax = 2000

    ! time path of consumption and resource
    integer, parameter :: TT = 200
    real*8 :: c_t(0:TT), a_t(0:TT)

    ! value and policy function
    integer, parameter :: NA = 1000
    real*8 :: a(0:NA), c(0:NA), V(0:NA)

    ! variables to numerically determine value and policy function
    real*8 :: V_new(0:NA), coeff_V(NA+3), coeff_c(NA+3)
    real*8 :: con_lev, x_in, fret

    ! variables to communicate with function
    integer :: ia_com

contains


    ! the function that should be minimized
    function utility(x_in)

        use toolbox

        implicit none
        real*8, intent(in) :: x_in
        real*8 :: utility, cons, vplus

        ! calculate consumption
        cons = max(a(ia_com) - x_in, 1d-10)

        ! calculate future utility
        vplus = max(spline_eval(x_in, coeff_V, 0d0, a0), 1d-10)**egam/egam

        ! get utility function
        utility = - (cons**egam/egam + beta*vplus)

    end function

end module

Hi @msilva913, the Julia code is not running well on my machine either – the process hangs after one iteration.

Can you please let me know the Bellman equation you are working on, with enough detail to understand the problem? I can try to understand from the code but it would be a bit easier to see the Bellman equation first.

Of course! This is a simple cake-eating problem, so the Bellman equation is just
$$V(a) = \max_{a'} \left\{ u(a - a') + \beta V(a') \right\}$$

I should also mention that, rather than interpolating the value function directly, here I interpolate a transform of the value function, which is supposed to be closer to linear, and then use the inverse transformation to recover the value function in the following lines:

V_temp = CubicSplineInterpolation(a, (egam*V).^(1.0/egam))(a_prime)
V_fun = max(V_temp, 1e-10)^(egam)/egam

Of course, having 1000 grid points obviously increases the time significantly, but the corresponding Fortran code has the same-size grid.
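As a quick sanity check on why the transform helps, here is a minimal sketch that uses the closed-form solution plotted in the Fortran program, V(a) = (1-β^γ)^(-1/γ) a^egam/egam. The names θ, K, V_exact and transform are introduced only for this illustration. For that closed form, the transformed value function is exactly linear in a, which is the best case for a spline:

```julia
γ = 0.5
egam = 1 - 1 / γ
β = 0.95

θ = 1 - β^γ                      # analytical consumption share, c(a) = θ * a
K = θ^(-1 / γ)                   # constant in the closed-form value function
V_exact(a) = K * a^egam / egam   # closed form used for the "analytical" plots
transform(v) = (egam * v)^(1 / egam)

a = range(1.0, 100.0, length=5)
# ratio of transformed value to a; constant if the transform is linear in a
slopes = transform.(V_exact.(a)) ./ a
println(slopes)
```

Every entry of slopes should equal K^(1/egam) up to rounding. The iterates of the numerical scheme are not exactly of this form, so this only suggests that the transformed function stays close to linear.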

Thanks @msilva913

I found that CubicSplineInterpolation is slow. I wrote a variation on your code that solves the problem using LinearInterpolation and it’s fast. I used the same parameters and grid size. It runs in about 2 seconds on my little laptop.

There are good reasons to use linear interpolation – it preserves the contractivity of the Bellman operator.
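A rough illustration of that property, using a hand-written routine rather than Interpolations.jl (lininterp and the two test vectors are made up for this example): inside the grid, a piecewise-linear interpolant is a convex combination of neighboring grid values, so the sup-norm distance between two interpolated functions cannot exceed their distance on the grid.

```julia
# Piecewise-linear interpolation on a sorted grid (illustrative helper).
function lininterp(grid, vals, x)
    # index of the left end of the interval containing x
    i = clamp(searchsortedlast(grid, x), 1, length(grid) - 1)
    w = (x - grid[i]) / (grid[i+1] - grid[i])
    return (1 - w) * vals[i] + w * vals[i+1]   # convex combination when x is inside the grid
end

grid = range(0.0, 100.0, length=11)
v = sqrt.(grid)                 # one candidate value function on the grid
w = v .+ 0.3 .* sin.(grid)      # a perturbed candidate

xs = range(0.0, 100.0, length=1001)
gap_on_grid = maximum(abs.(v .- w))
gap_off_grid = maximum(abs(lininterp(grid, v, x) - lininterp(grid, w, x)) for x in xs)
println((gap_on_grid, gap_off_grid))
```

The second number should never exceed the first. Cubic splines carry no such guarantee, since they can overshoot between grid points.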

Please let me know how it works for you.


using Interpolations, Random, Optim, Statistics, PyPlot
import Base.@kwdef


@kwdef struct CakeEatingProblem
    γ::Float64 = 0.5
    β::Float64 = 0.95
end

# Utility function
u(c, γ) = c^(1 - γ) / (1 - γ)

# Right hand side of Bellman equation for each choice of consumption
function state_action_value(cep::CakeEatingProblem, 
                            a,              # assets
                            c,              # consumption 
                            grid,           # grid points
                            v_on_grid)      # guess of v on grid points
    (; β, γ) = cep     # Unpack
    v = LinearInterpolation(grid, v_on_grid)

    return u(c, γ) + β * v(a - c)
end

# Bellman operator
function T(cep::CakeEatingProblem, grid, v_on_grid)
    Tv = similar(v_on_grid)
    for (i, a) in enumerate(grid)
        objective = c -> state_action_value(cep, a, c, grid, v_on_grid)
        results = maximize(objective, 0.0, a)
        Tv[i] = Optim.maximum(results)
    end
    return Tv
end

# Get the optimal policy corresponding to a guess of the value function
function get_greedy(cep::CakeEatingProblem, grid, v_on_grid)
    c_policy = similar(v_on_grid)
    for (i, a) in enumerate(grid)
        objective = c -> state_action_value(cep, a, c, grid, v_on_grid)
        results = maximize(objective, 0.0, a)
        c_policy[i] = Optim.maximizer(results)
    end
    return c_policy
end

function value_function_iteration(cep::CakeEatingProblem,
                                  grid_min=0.0,
                                  grid_max=100.0,
                                  grid_length=1_000,
                                  max_iter=2_000,
                                  error_tol=1e-6)
    (; β, γ) = cep     # Unpack
    grid = range(grid_min, grid_max, length=grid_length)
    v_init = [u(c, γ) for c in grid]
    new_v = similar(v_init)

    iter_num = 0
    error = Inf
    v = v_init
    while error > error_tol && iter_num < max_iter
        new_v = T(cep, grid, v)
        error = maximum(abs.(new_v - v))
        iter_num += 1
        v = new_v
    end

    # policy implied by the final iterate, not by the initial guess
    c_policy = get_greedy(cep, grid, v)

    return grid, new_v, c_policy
end


# Main
cep = CakeEatingProblem()

a, v, c = value_function_iteration(cep)

fig, ax = subplots(nrows=1, ncols=2, figsize=(10, 5))
ax[1].plot(a, c, label="policy function")
ax[2].plot(a, v, label="value function")
ax[1].legend()
ax[2].legend()
tight_layout()

display(fig)

Dear John:

Thanks so much! This code is very efficient! Yes, CubicSplineInterpolation is remarkably slow. However, it seems that I had introduced some additional overhead, as your code was faster than merely replacing CubicSplineInterpolation with LinearInterpolation. I need to scrutinize that more.
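One candidate for that overhead, as an unbenchmarked sketch: TV constructs the spline on every call to the objective, whereas the Fortran program calls spline_interp once per iteration and only evaluates the spline inside utility. The function update_hoisted below is a hypothetical variant of update that builds the interpolant once per sweep over the grid. Following the Fortran code, it sets the a = 0 point manually. It uses CubicSplineInterpolation as in my original code; newer releases of Interpolations.jl also provide the lowercase cubic_spline_interpolation.

```julia
using Interpolations, Optim

# Hypothetical variant of `update`: the spline is constructed once per sweep,
# not once per objective evaluation.
function update_hoisted(a, V, egam, β)
    # interpolant of the transformed value function, built a single time
    V_tr = CubicSplineInterpolation(a, (egam .* V) .^ (1 / egam))
    V_new = similar(V)
    c = similar(V)
    for ia in eachindex(a)
        if a[ia] <= 0
            # set a = 0 manually, as the Fortran program does
            c[ia] = 0.0
            V_new[ia] = -1e10
            continue
        end
        # negative of the Bellman right-hand side at savings a_prime
        objective = a_prime -> begin
            cons = max(a[ia] - a_prime, 1e-10)
            V_next = max(V_tr(a_prime), 1e-10)^egam / egam
            -(cons^egam / egam + β * V_next)
        end
        result = optimize(objective, 0.0, a[ia], Brent())
        V_new[ia] = -Optim.minimum(result)
        c[ia] = a[ia] - Optim.minimizer(result)
    end
    return c, V_new
end

# one sweep on a small grid, with the parameter values from my first post
a = range(0.0, 100.0, length=50)
egam, β = 1 - 1 / 0.5, 0.95
V0 = collect(a ./ 2) .^ egam ./ egam ./ (1 - β)
c1, V1 = update_hoisted(a, V0, egam, β)
```

Constructing a cubic spline means solving for its coefficients over the whole grid, so doing it inside the objective repeats that work at every step of Brent's method. That would also be consistent with the switch to LinearInterpolation helping, since a linear interpolant needs no such solve.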

Thanks again!

Best,
Mario

Hi @msilva913 ,

glad that I could help.

I didn’t use the trick of transforming the value function in order to make it easier to interpolate. That might still be worth looking at. I think that trick is closely related to the idea of time iteration discussed here, although the code in that discussion is in Python.

Good luck with your studies.

John.

John has given a great answer, but one thing to keep in mind is that your struct Para has a field that is not concretely typed: a has no type annotation, so it is stored as Any. Sometimes this is irrelevant, and sometimes it makes an orders-of-magnitude difference in performance. In general, it is better to let Julia figure out the types on its own, especially for anything more complicated than a scalar (e.g. use struct Para{T1} and then a::T1 = range(0., 100., length=NA), etc.).
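A minimal way to see the difference, using plain structs without Parameters.jl (ParaUntyped and ParaTyped are made-up names for this illustration):

```julia
struct ParaUntyped
    NA::Int
    a              # no annotation, so the field type is Any
end

struct ParaTyped{T1}
    NA::Int
    a::T1          # the field type is fixed when the struct is constructed
end

grid = range(0.0, 100.0, length=1000)
p1 = ParaUntyped(1000, grid)
p2 = ParaTyped(1000, grid)

println(fieldtype(typeof(p1), :a))                   # the declared field type
println(isconcretetype(fieldtype(typeof(p1), :a)))   # is it concrete?
println(isconcretetype(fieldtype(typeof(p2), :a)))
```

With the untyped field, the compiler cannot know what a is when it compiles code that reads para.a, so operations on it are dispatched at run time. With the type parameter, the element type of the grid is part of the struct's type.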

Greetings!

Just to be clear, do you suggest writing the struct as follows?

@with_kw  struct Para{T1}
    # model parameters
    γ::Float64 = 0.5
    egam::Float64 = 1-1/γ
    β::Float64 = 0.95
    a0::Float64 = 100

    # numerical parameters
    σ::Float64 = 1e-6
    itermax::Int64 = 2000
    T::Int64 = 200
    NA::Int64 = 1000
    a::T1 = range(0., 100., length=NA)
end

Also, if I may impose with a small follow-up question: are there major performance considerations in passing more arguments inside a struct compared to having them as separate function arguments, or is that mostly a matter of style?

Thanks so much.