{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f07fa1f2-187e-4ce0-af95-31d6120977fe",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-31T11:15:47.604112Z",
     "start_time": "2024-10-31T11:15:47.544336Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from mimic.utilities import *\n",
    "\n",
    "from mimic.model_infer.infer_gLV_bayes import *\n",
    "from mimic.model_infer import *\n",
    "from mimic.model_simulate import *\n",
    "from mimic.model_simulate.sim_gLV import *\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import arviz as az\n",
    "import pymc as pm\n",
    "import pytensor\n",
    "import pytensor.tensor as at\n",
    "import pickle\n",
    "import cloudpickle\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82eb9f01",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Using Bayesian inference to infer the parameters of a (linearised) gLV model\n",
    "\n",
    "The generalized Lotka-Volterra equation takes the form (summing over the repeated indices $j$ and $l$)\n",
    "\n",
    "$$ \\frac{dX_i}{dt} = \\mu_i X_i + X_i M_{ij} X_j + X_i \\epsilon_{il} u_l $$\n",
    "\n",
    "where:\n",
    "- $X_i$ is the concentration of species $i$\n",
    "- $\\mu_i$ is its specific growth rate\n",
    "- $M_{ij}$ is the effect of species $j$ on the growth of species $i$\n",
    "- $\\epsilon_{il}$ is the susceptibility of species $i$ to the time-dependent perturbation $u_l$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4324950",
   "metadata": {},
   "source": [
    "### Bayesian inference with no shrinkage\n",
    "\n",
    "Fit a single simulated time course (`data-s5-r1.csv`) and compare the posterior medians of $\\mu$ and $M$ with the simulated parameters. The parameters in `params-s5.pkl` are created by `examples-sim-gLV.ipynb`; run that notebook first so the file exists in the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed7e4c8c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-31T11:17:37.557214Z",
     "start_time": "2024-10-31T11:16:01.339288Z"
    }
   },
   "outputs": [],
   "source": [
    "# read in pickled simulated parameters, mu, M, epsilon, created in examples-sim-gLV.ipynb\n",
    "# (only unpickle files you trust: pickle.load can execute arbitrary code)\n",
    "num_species = 5\n",
    "with open(\"params-s5.pkl\", \"rb\") as f:\n",
    "    params = pickle.load(f)\n",
    "M = params[\"M\"]\n",
    "mu = params[\"mu\"]\n",
    "epsilon = params[\"epsilon\"]\n",
    "\n",
    "# read in the data: column 0 holds the sample times, the next num_species columns the species observations\n",
    "num_timecourses = 1\n",
    "data = pd.read_csv(\"data-s5-r1.csv\")\n",
    "times = data.iloc[:, 0].values\n",
    "\n",
    "yobs = data.iloc[:, 1:(num_species+1)].values\n",
    "\n",
    "X, F = linearize_time_course_16S(yobs, times)\n",
    "\n",
    "# Define priors\n",
    "prior_mu_mean = 1.0\n",
    "prior_mu_sigma = 0.5\n",
    "\n",
    "# prior on the diagonal terms M_ii\n",
    "prior_Mii_mean = 0.0\n",
    "prior_Mii_sigma = 0.1\n",
    "\n",
    "# prior width for the M_ij interaction terms\n",
    "prior_Mij_sigma = 0.1\n",
    "\n",
    "\n",
    "# Sampling conditions\n",
    "draws = 500\n",
    "tune = 500\n",
    "chains = 4\n",
    "cores = 4\n",
    "\n",
    "inference = infergLVbayes()\n",
    "\n",
    "inference.set_parameters(X=X, F=F, prior_mu_mean=prior_mu_mean, prior_mu_sigma=prior_mu_sigma,\n",
    "                         prior_Mii_sigma=prior_Mii_sigma, prior_Mii_mean=prior_Mii_mean,\n",
    "                         prior_Mij_sigma=prior_Mij_sigma,\n",
    "                         draws=draws, tune=tune, chains=chains, cores=cores)\n",
    "\n",
    "idata = 
inference.run_inference()\n",
    "\n",
    "# To plot posterior distributions, uncomment:\n",
    "# inference.plot_posterior(idata)\n",
    "\n",
    "\n",
    "summary = az.summary(idata, var_names=[\"mu_hat\", \"M_ii_hat\", \"M_ij_hat\", \"M_hat\", \"sigma\"])\n",
    "print(summary[[\"mean\", \"sd\", \"r_hat\"]])\n",
    "\n",
    "# Save posterior samples to file\n",
    "az.to_netcdf(idata, 'model_posterior.nc')\n",
    "\n",
    "# get median mu_hat and M_hat over all chains and draws\n",
    "mu_h = np.median(idata.posterior[\"mu_hat\"].values, axis=(0, 1)).reshape(-1)\n",
    "M_h = np.median(idata.posterior[\"M_hat\"].values, axis=(0, 1))\n",
    "\n",
    "# compare fitted with simulated parameters\n",
    "compare_params(mu=(mu, mu_h), M=(M, M_h))\n",
    "\n",
    "# NOTE(review): M_h is transposed for simulation but not for compare_params above; confirm the M convention expected by sim_gLV\n",
    "predictor = sim_gLV(num_species=num_species, M=M_h.T, mu=mu_h)\n",
    "yobs_h, _, _, _, _ = predictor.simulate(times=times, init_species=yobs[0])\n",
    "plot_fit_gLV(yobs, yobs_h, times)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7fcb005f031b1ce7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-31T11:20:49.557644Z",
     "start_time": "2024-10-31T11:20:49.420375Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# read in pickled simulated parameters, mu, M, epsilon\n",
    "num_species = 5\n",
    "with open(\"params-s5.pkl\", \"rb\") as f:\n",
    "    params = pickle.load(f)\n",
    "M = params[\"M\"]\n",
    "mu = params[\"mu\"]\n",
    "epsilon = params[\"epsilon\"]\n",
    "\n",
    "# read in the data: column 0 holds the sample times, followed by\n",
    "# num_species columns for each of the num_timecourses time courses\n",
    "num_timecourses = 3\n",
    "data = pd.read_csv(\"data-s5-r3.csv\")\n",
    "times = data.iloc[:, 0].values\n",
    "\n",
    "yobs_1 = data.iloc[:, 1:(num_species+1)].values\n",
    "yobs_2 = data.iloc[:, (num_species+1):(2*num_species+1)].values\n",
    "yobs_3 = data.iloc[:, (2*num_species+1):(3*num_species+1)].values\n",
    "# ryobs[k] is the (times x species) block of time course k\n",
    "ryobs = np.array([yobs_1, yobs_2, yobs_3])\n",
    "\n",
    "# linearise every time course and stack the resulting rows\n",
    "X = np.array([], dtype=np.double).reshape(0, num_species+1)\n",
    "F = np.array([], dtype=np.double).reshape(0, num_species)\n",
    "\n",
    "for timecourse_idx in range(num_timecourses):\n",
    "    Xs, Fs = linearize_time_course_16S(ryobs[timecourse_idx], times)\n",
    "    X = np.vstack([X, Xs])\n",
    "    F = np.vstack([F, Fs])\n",
    "\n",
    "# initial species values of the last time course (explicit rather than relying on the leaked loop variable)\n",
    "init_species = ryobs[-1, 0, :]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d13d05a1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-31T11:21:52.005379Z",
     "start_time": "2024-10-31T11:20:50.711980Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X shape: (297, 6)\n",
      "F shape: (297, 5)\n",
      "Number of species: 5\n",
      "AdvancedSetSubtensor.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Initializing NUTS using jitter+adapt_diag...\n",
      "Multiprocess sampling (4 chains in 4 jobs)\n",
      "NUTS: [sigma, mu_hat, M_ii_hat_p, M_ij_hat]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0b6029cf4bee4f868421af968189d4bd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 37 seconds.\n", "/Users/chaniaclare/Documents/GitHub/MIMIC/venv/lib/python3.10/site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide\n", " (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)\n", "/Users/chaniaclare/Documents/GitHub/MIMIC/venv/lib/python3.10/site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide\n", " (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)\n", "/Users/chaniaclare/Documents/GitHub/MIMIC/venv/lib/python3.10/site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide\n", " (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)\n", "/Users/chaniaclare/Documents/GitHub/MIMIC/venv/lib/python3.10/site-packages/arviz/stats/diagnostics.py:596: RuntimeWarning: invalid value encountered in scalar divide\n", " (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "mu_hat/mu:\n", "[1.3832308 0.74120724 1.77738998 0.95705513 0.80500161]\n", "[1.27853844 0.55683415 2.06752757 0.86387608 0.70448068]\n", "\n", "M_hat/M:\n", "[[-0.06 -0. 0. -0. 0. ]\n", " [ 0. -0.09 0. 0. 0. ]\n", " [-0.03 -0.01 -0.13 -0.01 0. ]\n", " [-0. 0.04 0. -0.01 0. ]\n", " [ 0. 0.01 0. 0. -0.15]]\n", "\n", " [[-0.05 0. -0.025 0. 0. ]\n", " [ 0. -0.1 0. 0.05 0. ]\n", " [ 0. 0. -0.15 0. 0. ]\n", " [ 0. 0. 0. -0.01 0. ]\n", " [ 0.02 0. 0. 0. -0.2 ]]\n" ] }, { "data": { "image/png": "", "text/plain": [ "