# Source code for cosmoTransitions.multi_field_plotting

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import matplotlib.pyplot as plt
import numpy as np

import sys
# Python 2/3 compatibility: Python 3 dropped ``xrange``, so alias it to the
# (equivalent, lazy) ``range`` builtin there.
if sys.version_info[0] >= 3:
    xrange = range

class MultiFieldPlotter:
    """
    This class tries to make it easier to view functions of more than two
    variables. For each set of two variables (or 'fields', since this is part
    of the CosmoTransitions package), this class will display a separate
    subplot in a managed figure. Each subplot is a different slice through the
    multi-dimensional space. By clicking on the subplots, the user can
    dynamically change the offsets of the slices in the other subplots.

    Parameters
    ----------
    bounds : array_like
        A list of ``(xmin, xmax)`` tuples for each dimension.
    f : callable
        The function to plot. The first argument must accept arrays of shape
        ``(..., Ndim)``, where `Ndim` is the number of dimensions.
    f_args : tuple, optional
        Extra arguments to pass to `f`.
    nx : int, optional
        Number of data points to plot in each dimension.
    contour_levs : int or array_like, optional
        If an array, a list of the contour levels to plot. If an int, the
        total number of contour levels across the bounding box (the contour
        levels are then calculated using :func:`calcContourLevels`).
    plot_1d : bool, optional
        If True, plot one-dimensional plots along with the contours.
        (not yet implemented: the diagonal axes are created and labelled,
        but left empty)
    plot_flipped : bool, optional
        If True, plot the flipped contour for each field (so that the
        subplots form a square grid rather than a triangle).

    Attributes
    ----------
    figure : matplotlib.figure.Figure
    offset : array_like
        Each slice intersects the point given by `offset`. Initially set to
        the average of `bounds` and interactively modifiable by clicking on
        the plots.
    draws_offset : bool
        Set to True if the plots should draw the offset point (as
        intersecting lines).

    Example
    -------
    The following example will make three contour plots whose offsets can be
    changed interactively:

    >>> from multi_field_plotting import MultiFieldPlotter
    >>> def V(X): # Some potential that looks vaguely interesting
    ...     x,y,z = X[...,0], X[...,1], X[...,2]
    ...     return x*x - x**3 + x*y + y**2 - y*z**2 + z**4
    >>> mfp = MultiFieldPlotter([[-1,1.],[-1,1],[-1,1]], V)
    """
    def __init__(self, bounds, f, f_args=(), nx=40, contour_levs=50,
                 plot_1d=False, plot_flipped=False):
        self.bounds = np.array(bounds)
        self.f = f
        self.f_args = f_args
        self.nx = nx
        self.contour_levs = np.array(contour_levs)
        if self.contour_levs.ndim == 0:
            # A scalar is a *number* of levels; replace it by explicit levels.
            self.calcContourLevels(self.contour_levs)
        self.plot_1d = plot_1d
        self.plot_flipped = plot_flipped
        self.figure = plt.figure()
        # Axes keyed by their subplot layout position, so that redraws reuse
        # them instead of adding a fresh Axes on every click.
        self._axes = {}
        # Make the offset the center of the data bounds
        self.offset = np.average(self.bounds, axis=1)
        # With two or fewer fields no other subplot depends on the offset,
        # so there is nothing useful to mark.
        self.draws_offset = len(self.bounds) > 2
        self.figure.canvas.mpl_connect('button_press_event', self._mouseDown)
        self.drawSubplot()

    def calcContourLevels(self, num_levs, nx=11):
        """
        Find the contour levels which span the bounds. Store in
        ``self.contour_levs``.

        The function is sampled on a regular grid spanning `bounds`. The
        returned levels extend 10% of the sampled range beyond its minimum and
        maximum. ``round(1.2*num_levs)`` levels are used, so the spacing stays
        close to ``range/num_levs``.

        Parameters
        ----------
        num_levs : int
            Desired number of contour levels.
        nx : int, optional
            The number of data points along each dimension that are used to
            find the minimum and maximum levels.
        """
        Ndim = len(self.bounds)
        X = np.empty([nx]*Ndim + [Ndim])
        for i, (xmin, xmax) in enumerate(self.bounds):
            # Swap grid axis i into the last grid slot so the 1-D linspace
            # broadcasts along it. Y is a view, so this fills X[..., i].
            Y = X.swapaxes(i, -2)
            Y[..., i] = np.linspace(xmin, xmax, nx)
        Z = self.f(X, *self.f_args)
        fmin = np.min(Z)
        fmax = np.max(Z)
        df = fmax - fmin
        if df == 0:
            # Constant function: without this every level would be identical,
            # which is not a valid (increasing) set of contour levels.
            df = 1.0
        # np.linspace requires an integer number of samples.
        num = int(round(1.2 * num_levs))
        self.contour_levs = np.linspace(fmin - df*.1, fmax + df*.1, num)

    def _subplotAxes(self, xfield, yfield):
        """Return the Axes for field pair (xfield, yfield), creating it once."""
        Ndim = len(self.bounds)
        if self.plot_1d or self.plot_flipped:
            # Full Ndim x Ndim grid: row = yfield, column = xfield.
            nrows_cols = Ndim
            plot_num = 1 + xfield + nrows_cols*yfield
        else:
            # Lower triangle only: rows hold yfield = 1..Ndim-1.
            nrows_cols = Ndim - 1
            plot_num = 1 + xfield + nrows_cols*(yfield-1)
        axes = getattr(self, '_axes', None)
        if axes is None:
            axes = self._axes = {}
        key = (nrows_cols, plot_num)
        ax = axes.get(key)
        if ax is None:
            ax = self.figure.add_subplot(nrows_cols, nrows_cols, plot_num)
            axes[key] = ax
        return ax

    def drawSubplot(self, subplot='all'):
        """
        Performs the actual drawing.

        Parameters
        ----------
        subplot : (int, int) or 'all'
            The subplot to redraw. If a tuple, it should be the field indices
            of the x and y axes. Pairs that are not displayed are skipped:
            diagonal pairs unless `plot_1d` is set, and pairs whose x-field
            exceeds their y-field unless `plot_flipped` is set.
        """
        Ndim = len(self.bounds)
        if subplot == 'all':
            for i in range(Ndim):
                for j in range(Ndim):
                    self.drawSubplot((i, j))
            return
        xfield, yfield = subplot
        if not self.plot_1d and xfield == yfield:
            return
        if not self.plot_flipped and xfield > yfield:
            return
        ax = self._subplotAxes(xfield, yfield)
        ax.clear()
        # Remember which fields this axes shows (read back by _mouseDown).
        ax.xfield, ax.yfield = xfield, yfield
        # Label only the outer edges: x labels on the bottom row, y labels on
        # the left column.
        if yfield == Ndim-1:
            ax.set_xlabel("$x_%i$" % xfield)
        if xfield == 0:
            if yfield == 0:
                ax.set_ylabel("$f(x_0)$")
            else:
                ax.set_ylabel("$x_%i$" % yfield)
        # Generate the data and make the plot
        if xfield == yfield:
            return  # 1d plot (not yet implemented)
        # Every point of the slice starts at the offset; then the two plotted
        # fields vary along the first and second grid axes respectively.
        X = np.empty((self.nx, self.nx, Ndim))
        X[:] = self.offset
        X[:, :, xfield] = np.linspace(
            self.bounds[xfield, 0], self.bounds[xfield, 1], self.nx
            )[:, np.newaxis]
        X[:, :, yfield] = np.linspace(
            self.bounds[yfield, 0], self.bounds[yfield, 1], self.nx
            )[np.newaxis, :]
        Z = self.f(X, *self.f_args)
        ax.contour(X[:, :, xfield], X[:, :, yfield], Z, self.contour_levs)
        if self.draws_offset:
            # Cross-hairs through the current offset point.
            xbounds = self.bounds[xfield]
            ybounds = self.bounds[yfield]
            x0 = self.offset[xfield]
            y0 = self.offset[yfield]
            ax.plot(xbounds, [y0, y0], 'k', lw=1.)
            ax.plot([x0, x0], ybounds, 'k', lw=1.)

    def _mouseDown(self, event):
        """Move the slice offset to the clicked point and redraw all plots."""
        ax = event.inaxes
        # Ignore clicks outside any axes, or in axes this class did not draw.
        if ax is None or not hasattr(ax, 'xfield'):
            return
        self.offset[ax.xfield] = event.xdata
        if ax.yfield != ax.xfield:
            # On a diagonal (1-d) axes the y coordinate is a function value,
            # not a field value, so it must not be written into the offset.
            self.offset[ax.yfield] = event.ydata
        self.drawSubplot()
        # Ask the canvas to repaint so the updated slices become visible.
        self.figure.canvas.draw_idle()