import json
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
cols = ["red", "green", "blue", "royalblue", "orange", "black"]
[docs]
def plot_gLV(yobs, timepoints):
# fig, axs = plt.subplots(1, 2, layout='constrained') # Optional
# alternative
fig, axs = plt.subplots(1, 1)
for species_idx in range(yobs.shape[1]):
label = f'Species {species_idx + 1}' # Add a label for each species
axs.plot(timepoints, yobs[:, species_idx],
color=cols[species_idx], label=label)
axs.set_xlabel('time')
axs.set_ylabel('[species]')
axs.legend() # Ensure the legend is called on the correct axes
plt.show()
[docs]
def plot_CRM(observed_species, observed_resources, timepoints, csv_file=None):
# Create a single axis
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
cols = plt.cm.tab10.colors
total_entities = observed_species.shape[1] + observed_resources.shape[1]
# Plot each species
for species_idx in range(observed_species.shape[1]):
label = f'Species {species_idx + 1}'
ax.plot(timepoints,
observed_species[:,
species_idx],
color=cols[species_idx],
label=label)
# Plot each resource - using distinct colors that continue from where
# species left off
for resource_idx in range(observed_resources.shape[1]):
# Use a different color index for resources (continuing from where
# species left off)
color_idx = observed_species.shape[1] + resource_idx
color_idx = color_idx % len(cols)
label = f'Resource {resource_idx + 1}'
ax.plot(timepoints,
observed_resources[:,
resource_idx],
linestyle='--',
color=cols[color_idx],
label=label)
# If CSV file is provided, overlay the observed data
if csv_file:
import pandas as pd
data = pd.read_csv(csv_file)
# Extract time and data columns
time_col = data.columns[0] # Assuming first column is time
# Plot observed species data with markers
num_species = observed_species.shape[1]
for i in range(num_species):
species_col = f'species_{i+1}'
if species_col in data.columns:
ax.scatter(data[time_col],
data[species_col],
marker='o',
color=cols[i % len(cols)],
s=10,
alpha=0.7,
label=f'Observed {species_col}')
# Plot observed resource data with different markers and consistent
# colors with simulated resources
num_resources = observed_resources.shape[1]
for i in range(num_resources):
resource_col = f'resource_{i+1}'
if resource_col in data.columns:
color_idx = num_species + i
color_idx = color_idx % len(cols)
ax.scatter(data[time_col], data[resource_col],
marker='s', color=cols[color_idx], s=10, alpha=0.7,
label=f'Observed {resource_col}')
# Set axis labels
ax.set_xlabel('Time', fontsize=12)
ax.set_ylabel('Concentration', fontsize=12)
ax.set_title('CRM Growth Curves: Simulated vs Observed', fontsize=14)
# Add a legend to label both species and resources
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
# Add grid for better readability
ax.grid(True, linestyle='--', alpha=0.7)
# Adjust layout to make room for the legend
plt.tight_layout()
# Show the plot
plt.show()
return fig, ax
[docs]
def plot_CRM_with_intervals(
observed_species,
observed_resources,
species_lower,
species_upper,
resource_lower,
resource_upper,
times,
filename=None):
fig, ax = plt.subplots(figsize=(12, 8))
# Plot median trajectories
for i in range(observed_species.shape[1]):
ax.plot(times, observed_species[:, i],
label=f'Species {i+1}', linewidth=2)
for i in range(observed_resources.shape[1]):
ax.plot(times,
observed_resources[:,
i],
label=f'Resource {i+1}',
linewidth=2,
linestyle='--')
# Add confidence ribbons
for i in range(observed_species.shape[1]):
ax.fill_between(times, species_lower[:, i], species_upper[:, i],
alpha=0.2, color=plt.cm.tab10(i))
for i in range(observed_resources.shape[1]):
ax.fill_between(times,
resource_lower[:,
i],
resource_upper[:,
i],
alpha=0.2,
color=plt.cm.tab10(i + observed_species.shape[1]))
if filename:
true_data = pd.read_csv(filename)
true_times = true_data['time'].values
for i in range(observed_species.shape[1]):
col_name = f'species_{i+1}'
if col_name in true_data.columns:
ax.scatter(
true_times,
true_data[col_name],
marker='o',
s=30,
color=plt.cm.tab10(i),
label=f'True {col_name}')
for i in range(observed_resources.shape[1]):
col_name = f'resource_{i+1}'
if col_name in true_data.columns:
ax.scatter(
true_times,
true_data[col_name],
marker='s',
s=30,
color=plt.cm.tab10(
i + observed_species.shape[1]),
label=f'True {col_name}')
ax.set_xlabel('Time', fontsize=14)
ax.set_ylabel('Concentration', fontsize=14)
ax.set_title(
'Consumer-Resource Model Dynamics with 95% Credible Intervals',
fontsize=16)
ax.legend(loc='best', fontsize=12)
ax.grid(True, alpha=0.3)
plt.tight_layout()
if filename:
plt.savefig(f"{filename.split('.')[0]}_with_intervals.png", dpi=300)
plt.show()
[docs]
def plot_gMLV(yobs, sobs, timepoints):
# fig, axs = plt.subplots(1, 2, layout='constrained')
fig, axs = plt.subplots(1, 2)
for species_idx in range(yobs.shape[1]):
axs[0].plot(timepoints, yobs[:, species_idx], color=cols[species_idx])
axs[0].set_xlabel('time')
axs[0].set_ylabel('[species]')
if sobs.shape[1] > 0:
for metabolite_idx in range(sobs.shape[1]):
axs[1].plot(timepoints, sobs[:, metabolite_idx],
color=cols[metabolite_idx])
axs[1].set_xlabel('time')
axs[1].set_ylabel('[metabolite]')
[docs]
def plot_fit_gMLV(yobs, yobs_h, sobs, sobs_h, timepoints):
# plot the fit
# fig, axs = plt.subplots(1, 2, layout='constrained')
fig, axs = plt.subplots(1, 2)
for species_idx in range(yobs.shape[1]):
axs[0].plot(timepoints, yobs[:, species_idx], color=cols[species_idx])
axs[0].plot(timepoints, yobs_h[:, species_idx],
'--', color=cols[species_idx])
axs[0].set_xlabel('time')
axs[0].set_ylabel('[species]')
for metabolite_idx in range(sobs.shape[1]):
axs[1].plot(timepoints, sobs[:, metabolite_idx],
color=cols[metabolite_idx])
axs[1].plot(timepoints, sobs_h[:, metabolite_idx],
'--', color=cols[metabolite_idx])
axs[1].set_xlabel('time')
axs[1].set_ylabel('[metabolite]')
[docs]
def plot_fit_gLV(yobs, yobs_h, timepoints):
# plot the fit
# fig, axs = plt.subplots(1, 2, layout='constrained')
fig, axs = plt.subplots(1, 1)
for species_idx in range(yobs.shape[1]):
axs.plot(timepoints, yobs[:, species_idx], color=cols[species_idx])
axs.plot(timepoints, yobs_h[:, species_idx],
'--', color=cols[species_idx])
axs.set_xlabel('time')
axs.set_ylabel('[species]')
# def compare_params(mu=None, M=None, alpha=None, e=None):
# # each argument is a tuple of true and predicted values (mu, mu_hat)
# if mu is not None:
# print("mu_hat/mu:")
# print(np.array(mu[1]))
# print(np.array(mu[0]))
# fig, ax = plt.subplots()
# ax.stem(np.arange(0, len(mu[0]), dtype="int32"),
# np.array(mu[1]), markerfmt="D", label='mu_hat', linefmt='C0-')
# ax.stem(np.arange(0, len(mu[0]), dtype="int32"),
# np.array(mu[0]), markerfmt="X", label='mu', linefmt='C1-')
# ax.set_xlabel('i')
# ax.set_ylabel('mu[i]')
# ax.legend()
# if M is not None:
# print("\nM_hat/M:")
# print(np.round(np.array(M[1]), decimals=2))
# print("\n", np.array(M[0]))
# fig, ax = plt.subplots()
# ax.stem(
# np.arange(
# 0,
# M[0].shape[0] ** 2),
# np.array(
# M[1]).flatten(),
# markerfmt="D",
# label='M_hat',
# linefmt='C0-')
# ax.stem(
# np.arange(
# 0,
# M[0].shape[0] ** 2),
# np.array(
# M[0]).flatten(),
# markerfmt="X",
# label='M',
# linefmt='C1-')
# ax.set_ylabel('M[i,j]')
# ax.legend()
# if alpha is not None:
# print("\na_hat/a:")
# print(np.round(np.array(alpha[1]), decimals=2))
# print("\n", np.array(alpha[0]))
# fig, ax = plt.subplots()
# ax.stem(
# np.arange(
# 0,
# alpha[0].shape[0] *
# alpha[0].shape[1]),
# np.array(
# alpha[1]).flatten(),
# markerfmt="D",
# label='a_hat',
# linefmt='C0-')
# ax.stem(
# np.arange(
# 0,
# alpha[0].shape[0] *
# alpha[0].shape[1]),
# np.array(
# alpha[0]).flatten(),
# markerfmt="X",
# label='a',
# linefmt='C1-')
# ax.set_ylabel('a[i,j]')
# ax.legend()
# if e is not None:
# print("\ne_hat/e:")
# print(np.round(np.array(e[1]), decimals=2))
# print("\n", np.array(e[0]))
# fig, ax = plt.subplots()
# ax.stem(np.arange(0, e[0].shape[0]), np.array(
# e[1]).flatten(), markerfmt="D", label='e_hat', linefmt='C0-')
# ax.stem(np.arange(0, e[0].shape[0]), np.array(
# e[0]).flatten(), markerfmt="X", label='e', linefmt='C1-')
# ax.set_ylabel('e[i]')
# ax.legend()
[docs]
def compare_params(**kwargs):
"""
Compare inferred and observed parameters with any parameter name.
Parameters:
----------
**kwargs : Each argument should be a tuple of (true_value, predicted_value)
where true_value and predicted_value are numpy arrays
"""
import numpy as np
import matplotlib.pyplot as plt
for param_name, param_values in kwargs.items():
true_val, pred_val = param_values
# Print comparison
print(f"\n{param_name}_hat/{param_name}:")
print(np.round(np.array(pred_val), decimals=2))
print("\n", np.array(true_val))
# Create figure
fig, ax = plt.subplots()
# Handle different shapes of parameters
true_array = np.array(true_val)
pred_array = np.array(pred_val)
# Determine x-axis size based on parameter shape
if true_array.ndim > 1:
x_size = true_array.size
true_flat = true_array.flatten()
pred_flat = pred_array.flatten()
else:
x_size = len(true_array)
true_flat = true_array
pred_flat = pred_array
# Create stem plots
ax.stem(
np.arange(0, x_size, dtype="int32"),
pred_flat,
markerfmt="D",
label=f'{param_name}_hat',
linefmt='C0-'
)
ax.stem(
np.arange(0, x_size, dtype="int32"),
true_flat,
markerfmt="X",
label=f'{param_name}',
linefmt='C1-'
)
# Set labels and legend
ax.set_xlabel('index')
ax.set_ylabel(
f'{param_name}[i]' if true_array.ndim == 1 else f'{param_name}[i,j]')
ax.legend()
# Show plot
plt.tight_layout()
plt.show()
[docs]
def set_all_seeds(seed):
np.random.seed(seed)
random.seed(seed)
[docs]
def read_parameters(json_file):
current_dir = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(current_dir, json_file)
with open(file_path, 'r') as f:
parameters = json.load(f)
return parameters