{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Optimization of hyper-parameters\n", "\n", "The model evidence can be used not only to compare different kinds of time series models, but also to optimize the hyper-parameters of a given transition model by maximizing its evidence value. The `Study` class of *bayesloop* contains a method `optimize` that relies on the `minimize` function of the `scipy.optimize` module. Since *bayesloop* has no gradient information about the hyper-parameters, the optimization routine is based on the [COBYLA](https://en.wikipedia.org/wiki/COBYLA) algorithm. The following two sections introduce the optimization of hyper-parameters using *bayesloop* and further describe how to selectively optimize specific hyper-parameters in nested transition models." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-04-27T19:16:59.427545Z", "iopub.status.busy": "2026-04-27T19:16:59.427212Z", "iopub.status.idle": "2026-04-27T19:17:00.460234Z", "shell.execute_reply": "2026-04-27T19:17:00.459801Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+ Created new study.\n", "+ Successfully imported example data.\n", "+ Observation model: Poisson. Parameter(s): ['accident_rate']\n" ] } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt # plotting\n", "plt.style.use('seaborn-v0_8-whitegrid') # plot styling\n", "\n", "import numpy as np\n", "import bayesloop as bl\n", "\n", "# prepare study for coal mining data\n", "S = bl.Study()\n", "S.load_example_data()\n", "\n", "L = bl.om.Poisson('accident_rate', bl.oint(0, 6, 1000))\n", "S.set(L)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Global optimization\n", "\n", "The `optimize` method supports all currently implemented transition models with continuous hyper-parameters, as well as combinations of multiple models. 
The change-point model and the serial transition model are exceptions here, as their parameters `t_change` and `t_break`, respectively, are discrete. These discrete parameters are ignored by the optimization routine. See the [tutorial on change-point studies](changepointstudy.ipynb) for further information on how to analyze structural breaks and change-points. By default, all continuous hyper-parameters of the transition model are optimized. *bayesloop* also allows selective optimization of specific hyper-parameters (see [below](#Conditional-optimization-in-nested-transition-models)). The parameter values set by the user when defining the transition model are used as starting values. During optimization, only the log-evidence of the model is computed. Once the optimization has finished, a full fit is performed to provide the parameter distributions and mean values for the optimal model setting.\n", "\n", "We return to the coal mining example and stick with the serial transition model defined [here](modelselection.ipynb#Serial-transition-model). This time, however, we optimize the slope of the linear decrease from 1885 to 1895 and the magnitude of the fluctuations afterwards (i.e., the standard deviation of the Gaussian random walk):" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-04-27T19:17:00.461513Z", "iopub.status.busy": "2026-04-27T19:17:00.461384Z", "iopub.status.idle": "2026-04-27T19:17:01.231788Z", "shell.execute_reply": "2026-04-27T19:17:01.231414Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+ Transition model: Serial transition model. 
Hyper-Parameter(s): ['slope', 'sigma', 't_1', 't_2']\n", "+ Starting optimization...\n", " --> All model parameters are optimized (except change/break-points).\n", " + Log10-evidence: -72.93384 - Parameter values: [-0.2 0.1]\n", " + Log10-evidence: -96.81252 - Parameter values: [0.8 0.1]\n", " + Log10-evidence: -75.18192 - Parameter values: [-0.2 1.1]\n", " + Log10-evidence: -78.43877 - Parameter values: [-1.19559753 0.00626874]\n", " + Log10-evidence: -139.79077 - Parameter values: [ 0.26557651 -0.08231432]\n", " + Log10-evidence: -76.83672 - Parameter values: [-0.44996974 0.09611054]\n", " + Log10-evidence: -73.19510 - Parameter values: [-0.19922211 0.05000605]\n", " + Log10-evidence: -74.29931 - Parameter values: [-0.10567583 0.13321071]\n", " + Log10-evidence: -73.37975 - Parameter values: [-0.29561058 0.12930218]\n", " + Log10-evidence: -73.13427 - Parameter values: [-0.20722357 0.05052455]\n", " + Log10-evidence: -73.07563 - Parameter values: [-0.17828693 0.11239124]\n", " + Log10-evidence: -72.90392 - Parameter values: [-0.20247825 0.10434261]\n", " + Log10-evidence: -72.85412 - Parameter values: [-0.2120502 0.10723705]\n", " + Log10-evidence: -72.81735 - Parameter values: [-0.21827602 0.11506259]\n", " + Log10-evidence: -72.81196 - Parameter values: [-0.22816886 0.11652268]\n", " + Log10-evidence: -72.78617 - Parameter values: [-0.22782114 0.12651663]\n", " + Log10-evidence: -72.76495 - Parameter values: [-0.22845054 0.13649681]\n", " + Log10-evidence: -72.74628 - Parameter values: [-0.23544537 0.15523373]\n", " + Log10-evidence: -72.71912 - Parameter values: [-0.21865359 0.16609818]\n", " + Log10-evidence: -72.72476 - Parameter values: [-0.2082647 0.18318827]\n", " + Log10-evidence: -72.74145 - Parameter values: [-0.21052728 0.1602704 ]\n", " + Log10-evidence: -72.71461 - Parameter values: [-0.22321455 0.16814701]\n", " + Log10-evidence: -72.71584 - Parameter values: [-0.22807714 0.16931112]\n", " + Log10-evidence: -72.71589 - Parameter values: 
[-0.22350557 0.16693136]\n", " + Log10-evidence: -72.71252 - Parameter values: [-0.22208028 0.17037489]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.70987 - Parameter values: [-0.21953447 0.17467825]\n", " + Log10-evidence: -72.71266 - Parameter values: [-0.21534419 0.17740616]\n", " + Log10-evidence: -72.70733 - Parameter values: [-0.2213498 0.17639713]\n", " + Log10-evidence: -72.70525 - Parameter values: [-0.22258147 0.17857268]\n", " + Log10-evidence: -72.70164 - Parameter values: [-0.22334162 0.18351456]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.70153 - Parameter values: [-0.22978427 0.19116259]\n", " + Log10-evidence: -72.69632 - Parameter values: [-0.2259854 0.19441353]\n", " + Log10-evidence: -72.69518 - Parameter values: [-0.21840356 0.20093393]\n", " + Log10-evidence: -72.69043 - Parameter values: [-0.22297007 0.20983039]\n", " + Log10-evidence: -72.69596 - Parameter values: [-0.23199007 0.22768087]\n", " + Log10-evidence: -72.69976 - Parameter values: [-0.21314869 0.21171198]\n", " + Log10-evidence: -72.68994 - Parameter values: [-0.22344047 0.21228573]\n", " + Log10-evidence: -72.69263 - Parameter values: [-0.22843977 0.21236963]\n", " + Log10-evidence: -72.69008 - Parameter values: [-0.221257 0.21350331]\n", " + Log10-evidence: -72.68989 - Parameter values: [-0.22412754 0.21301232]\n", " + Log10-evidence: -72.69017 - Parameter values: [-0.22512751 0.21302082]\n", " + Log10-evidence: -72.68968 - Parameter values: [-0.22347647 0.21377134]\n", " + Log10-evidence: -72.68955 - Parameter values: [-0.22291543 0.21459912]\n", " + Log10-evidence: -72.68940 - Parameter values: [-0.22303637 0.21559178]\n", " + Log10-evidence: -72.68915 - Parameter values: [-0.22296082 0.21759036]\n", " + Log10-evidence: -72.68893 - Parameter values: [-0.22200498 0.22147447]\n", " + Log10-evidence: -72.68928 - Parameter values: [-0.22565911 0.22310154]\n", " + Log10-evidence: -72.68968 - 
Parameter values: [-0.22007875 0.22201265]\n", " + Log10-evidence: -72.68887 - Parameter values: [-0.22213952 0.22195603]\n", " + Log10-evidence: -72.68876 - Parameter values: [-0.22313912 0.2219843 ]\n", " + Log10-evidence: -72.68877 - Parameter values: [-0.2238879 0.22264711]\n", " + Log10-evidence: -72.68879 - Parameter values: [-0.22345856 0.22159964]\n", " + Log10-evidence: -72.68876 - Parameter values: [-0.22295726 0.22215584]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.68876 - Parameter values: [-0.22274716 0.22229134]\n", " + Log10-evidence: -72.68876 - Parameter values: [-0.22293016 0.22211382]\n", " + Log10-evidence: -72.68875 - Parameter values: [-0.22303482 0.22221897]\n", " + Log10-evidence: -72.68874 - Parameter values: [-0.22319773 0.22233499]\n", " + Log10-evidence: -72.68874 - Parameter values: [-0.22339216 0.22238187]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.68874 - Parameter values: [-0.22318601 0.22238359]\n", " + Log10-evidence: -72.68873 - Parameter values: [-0.22315149 0.22247745]\n", " + Log10-evidence: -72.68872 - Parameter values: [-0.22307598 0.22266265]\n", " + Log10-evidence: -72.68870 - Parameter values: [-0.22299684 0.22305474]\n", " + Log10-evidence: -72.68871 - Parameter values: [-0.22271968 0.22334315]\n", " + Log10-evidence: -72.68870 - Parameter values: [-0.22315679 0.22317481]\n", " + Log10-evidence: -72.68869 - Parameter values: [-0.22333187 0.2232715 ]\n", " + Log10-evidence: -72.68871 - Parameter values: [-0.2234958 0.22315693]\n", " + Log10-evidence: -72.68869 - Parameter values: [-0.22330004 0.2233663 ]\n", " + Log10-evidence: -72.68868 - Parameter values: [-0.22327744 0.22346371]\n", " + Log10-evidence: -72.68868 - Parameter values: [-0.2232448 0.22366103]\n", " + Log10-evidence: -72.68866 - Parameter values: [-0.2232272 0.22406064]\n", " + Log10-evidence: -72.68864 - Parameter values: [-0.22292064 0.22479957]\n", " + Log10-evidence: 
-72.68863 - Parameter values: [-0.22352162 0.2253276 ]\n", " + Log10-evidence: -72.68861 - Parameter values: [-0.22347575 0.22612629]\n", " + Log10-evidence: -72.68859 - Parameter values: [-0.22320993 0.22770405]\n", " + Log10-evidence: -72.68877 - Parameter values: [-0.22453692 0.22859797]\n", " + Log10-evidence: -72.68865 - Parameter values: [-0.22241267 0.22763781]\n", " + Log10-evidence: -72.68859 - Parameter values: [-0.22319337 0.22790336]\n", " + Log10-evidence: -72.68861 - Parameter values: [-0.22360992 0.22770661]\n", " + Log10-evidence: -72.68860 - Parameter values: [-0.22301513 0.22765872]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.68859 - Parameter values: [-0.22330728 0.2276812 ]\n", "+ Finished optimization.\n", "+ Started new fit:\n", " + Formatted data.\n", " + Set prior (function): jeffreys. Values have been re-normalized.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f9437ebcf6464976a8995adf6230112f", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/110 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# define linear decrease transition model\n", "def linear(t, slope=-0.2):\n", " return slope*t\n", "\n", "T = bl.tm.SerialTransitionModel(bl.tm.Static(),\n", " bl.tm.BreakPoint('t_1', 1885),\n", " bl.tm.Deterministic(linear, target='accident_rate'),\n", " bl.tm.BreakPoint('t_2', 1895),\n", " bl.tm.GaussianRandomWalk('sigma', 0.1, target='accident_rate'))\n", "S.set(T)\n", "\n", "S.optimize()\n", "\n", "plt.figure(figsize=(8, 4))\n", "plt.bar(S.raw_timestamps, S.raw_data, align='center', facecolor='r', alpha=.5)\n", "S.plot('accident_rate')\n", "plt.xlim([1851, 1962])\n", "plt.xlabel('year');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The optimal value for the standard deviation of the varying disaster rate is determined to be $\\approx 0.23$; the initial guess of $\\sigma = 0.1$ is therefore too restrictive. 
The slope changes only slightly during optimization, from its initial value of $-0.2$ to an optimal value of $\\approx -0.22$. The optimal hyper-parameter values are displayed in the output during optimization, but they can also be inspected directly:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-04-27T19:17:01.233089Z", "iopub.status.busy": "2026-04-27T19:17:01.232999Z", "iopub.status.idle": "2026-04-27T19:17:01.235075Z", "shell.execute_reply": "2026-04-27T19:17:01.234653Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "slope = -0.2232099271284006\n", "sigma = 0.22770405019767503\n" ] } ], "source": [ "print('slope =', S.get_hyper_parameter_value('slope'))\n", "print('sigma =', S.get_hyper_parameter_value('sigma'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conditional optimization in nested transition models\n", "\n", "The previous section introduced the `optimize` method of the `Study` class. By default, all (continuous) hyper-parameters of the chosen transition model are optimized. In some applications, however, only specific hyper-parameters should be optimized. For this purpose, a list of parameter names (or a single name) may be passed to `optimize` to specify which parameters are optimized. Note that each hyper-parameter must be given a unique name. An example of a (quite ridiculously) nested transition model is defined below. Note that the deterministic transition models are defined via `lambda` functions." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-04-27T19:17:01.236097Z", "iopub.status.busy": "2026-04-27T19:17:01.236021Z", "iopub.status.idle": "2026-04-27T19:17:01.238526Z", "shell.execute_reply": "2026-04-27T19:17:01.238108Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+ Transition model: Serial transition model. 
Hyper-Parameter(s): ['early_sigma', 'pmin', 'slope_1', 'late_sigma', 'slope_2', 'first_break', 'second_break']\n" ] } ], "source": [ "T = bl.tm.SerialTransitionModel(bl.tm.CombinedTransitionModel(\n", " bl.tm.GaussianRandomWalk('early_sigma', 0.05, target='accident_rate'),\n", " bl.tm.RegimeSwitch('pmin', -7)\n", " ),\n", " bl.tm.BreakPoint('first_break', 1885),\n", " bl.tm.Deterministic(lambda t, slope_1=-0.2: slope_1*t, target='accident_rate'),\n", " bl.tm.BreakPoint('second_break', 1895),\n", " bl.tm.CombinedTransitionModel(\n", " bl.tm.GaussianRandomWalk('late_sigma', 0.25, target='accident_rate'),\n", " bl.tm.Deterministic(lambda t, slope_2=0.0: slope_2*t, target='accident_rate')\n", " )\n", " )\n", "S.set(T)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This transition model assumes a combination of gradual and abrupt changes until 1885, followed by a deterministic decrease of the annual disaster rate until 1895. Afterwards, the disaster rate is modeled by a combination of a decreasing trend and random fluctuations. Rather than discussing how meaningful the proposed transition model really is, we focus on how to specify the different (groups of) hyper-parameters that we might want to optimize.\n", "\n", "Since each hyper-parameter name occurs only once within the transition model, a parameter can simply be referred to by its name: `S.optimize('pmin')`. You may also pass one or more hyper-parameters as a list: `S.optimize(['pmin'])`, `S.optimize(['pmin', 'slope_2'])`. 
For deterministic models, the argument name also represents the hyper-parameter name:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "execution": { "iopub.execute_input": "2026-04-27T19:17:01.239478Z", "iopub.status.busy": "2026-04-27T19:17:01.239404Z", "iopub.status.idle": "2026-04-27T19:17:01.579075Z", "shell.execute_reply": "2026-04-27T19:17:01.578677Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+ Starting optimization...\n", " --> Parameter(s) to optimize: ['slope_2']\n", " + Log10-evidence: -72.78352 - Parameter values: [0.]\n", " + Log10-evidence: -93.84882 - Parameter values: [1.]\n", " + Log10-evidence: -80.98325 - Parameter values: [-1.]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -91.84552 - Parameter values: [0.5]\n", " + Log10-evidence: -76.50153 - Parameter values: [-0.25]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -75.23947 - Parameter values: [0.1]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.90930 - Parameter values: [-0.1]\n", " + Log10-evidence: -73.67309 - Parameter values: [0.05]\n", " + Log10-evidence: -72.58918 - Parameter values: [-0.025]\n", " + Log10-evidence: -72.55250 - Parameter values: [-0.05]\n", " + Log10-evidence: -72.66318 - Parameter values: [-0.075]\n", " + Log10-evidence: -72.54887 - Parameter values: [-0.04]\n", " + Log10-evidence: -72.56953 - Parameter values: [-0.03]\n", " + Log10-evidence: -72.54769 - Parameter values: [-0.045]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.55250 - Parameter values: [-0.05]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.54753 - Parameter values: [-0.0425]\n", " + Log10-evidence: -72.54788 - Parameter values: [-0.0415]\n", " + Log10-evidence: -72.54741 - Parameter values: [-0.0435]\n", " + Log10-evidence: -72.54754 - Parameter 
values: [-0.0445]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.54744 - Parameter values: [-0.043]\n", " + Log10-evidence: -72.54742 - Parameter values: [-0.04375]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.54741 - Parameter values: [-0.0434]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " + Log10-evidence: -72.54741 - Parameter values: [-0.0436]\n", "+ Finished optimization.\n", "+ Started new fit:\n", " + Formatted data.\n", " + Set prior (function): jeffreys. Values have been re-normalized.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1ac5068916c9452f9c25dc349aaf2778", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/110 [00:00