I’m solving a time-dependent equation of the following form and need its roots in λ:
import numpy as np
import sympy as sp
import matplotlib.pyplot as plt
t = sp.symbols(r't', real=True, positive=True)
eq = ...
print(repr(eq))
-\lambda**3 - 2*\lambda**2*t + 4*\lambda*t**4 - 8*\lambda*t**3 + 8*\lambda*t**2 - 8*\lambda*t + 4*\lambda + 8*t**3 - 16*t**2 + 8*t
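Collecting the printed polynomial by powers of λ and factoring the coefficients by hand shows that it is a cubic in λ whose coefficients depend on t. This is only a rewrite of the repr output above:

```latex
-\lambda^3 - 2t\,\lambda^2 + 4(t-1)^2(t^2+1)\,\lambda + 8t(t-1)^2 = 0
```

At t = 1 it reduces to -λ²(λ + 2) = 0, so two of the roots approach 0 as t → 1.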
Solving the equation and saving the roots to a list:
sol = sp.solve(eq)
# sp.solve(eq) returns one dict per root here; keep only each dict's value
e_list = [list(s.values())[0] for s in sol]
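A possibly simpler way to build the list is to name the unknown when calling `sp.solve`, so it returns a plain list of expressions in t instead of dicts. This is a sketch under assumptions: the definition behind `eq = ...` is not shown, so the symbol `lam` (named `\lambda` to match the printed repr) and the rebuilt `eq` below are hypothetical reconstructions from that repr.

```python
import sympy as sp

# Hypothetical symbol definitions: the names mirror the printed repr, but the
# real definitions behind `eq = ...` are not shown in the question.
t = sp.symbols('t', real=True, positive=True)
lam = sp.Symbol(r'\lambda')

# Rebuilt from the repr(eq) output printed in the question
eq = (-lam**3 - 2*lam**2*t + 4*lam*t**4 - 8*lam*t**3 + 8*lam*t**2
      - 8*lam*t + 4*lam + 8*t**3 - 16*t**2 + 8*t)

# Naming the unknown explicitly makes solve() return a list of expressions
# (one per root) instead of a list of dicts.
e_list = sp.solve(eq, lam)
```

Each entry of `e_list` can then be passed straight to `sp.lambdify(t, ...)` as in the plotting code.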
Plotting how each root evolves with t:
x_vals = np.linspace(0.01, .9999, 1000, dtype=complex)
# Evaluate each root with the NumPy backend and plot its real part against t
for i, x in enumerate(e_list, start=1):
    lam_x = sp.lambdify(t, x, modules=['numpy'])
    y_vals = np.around(lam_x(x_vals), decimals=5)
    plt.plot(np.real(x_vals), np.real(y_vals), label=f'r{i}')
plt.legend()
plt.show()
But for some reason the first two roots get mixed up for t below about 0.72. I have no idea how to correct this so that they don't mix.
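The kind of jump described above can be reproduced in miniature with NumPy alone. This sketch only illustrates the general mechanism (principal square and cube roots are discontinuous across the negative real axis); it is not a trace through the actual expressions that `sp.solve` returned, and the sample point -4 is arbitrary.

```python
import numpy as np

# Two points a hair above and below the negative real axis, which is where
# the principal square root and cube root have their branch cut.
z = np.array([complex(-4.0, 1e-12), complex(-4.0, -1e-12)])

sqrt_vals = np.sqrt(z)          # roughly +2j above the cut, -2j below it
cbrt_vals = np.power(z, 1 / 3)  # principal cube roots on either side of the cut

# A tiny change in the input flips the sign of the imaginary part of the
# result, so a formula built from these roots can jump to a different root.
print(sqrt_vals)
print(cbrt_vals)
```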
Solution:
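That behavior is to be expected with NumPy: the closed-form roots that `sp.solve` returns for a cubic contain complex square and cube roots, and as t varies the evaluation goes through a complex branch cut. In this case you should use mpmath as the evaluation module for `lambdify`, which deals with branch cuts differently. For example: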
import numpy as np
import sympy as sp
import matplotlib.pyplot as plt
x_vals = np.linspace(0.01, .9999, 1000, dtype=complex)
plt.figure()
# r1 and r2: evaluate with mpmath, one point at a time
for i, x in enumerate(e_list[:2], start=1):
    lam_x = sp.lambdify(t, x, modules=['mpmath'])
    y_vals = [complex(lam_x(_x)) for _x in x_vals]
    plt.plot(np.real(x_vals), np.real(y_vals), label=f'r{i}')
# r3 did not mix, so it can stay on the NumPy backend
lam_x = sp.lambdify(t, e_list[2], modules=['numpy'])
y_vals = np.around(lam_x(x_vals), decimals=5)
plt.plot(np.real(x_vals), np.real(y_vals), label='r3')
plt.legend()
plt.show()
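A different technique from the mpmath fix, named plainly: skip the closed-form expressions, compute the three roots numerically with `np.roots` at each t, and reorder each new set so that every column continues from the nearest root of the previous step. This is a sketch; the helpers `cubic_coeffs` and `tracked_roots` are hypothetical names, and the coefficients are read off the printed polynomial.

```python
import itertools

import numpy as np


def cubic_coeffs(t):
    """Coefficients of the cubic in lambda (highest power first), read off
    the polynomial printed in the question."""
    return [-1.0,
            -2.0 * t,
            4 * t**4 - 8 * t**3 + 8 * t**2 - 8 * t + 4,
            8 * t**3 - 16 * t**2 + 8 * t]


def tracked_roots(t_vals):
    """Roots at each t, reordered so each column follows one continuous branch."""
    out = np.empty((len(t_vals), 3), dtype=complex)
    out[0] = np.sort_complex(np.roots(cubic_coeffs(t_vals[0])))
    for k in range(1, len(t_vals)):
        new = np.roots(cubic_coeffs(t_vals[k]))
        # Try all 3! assignments of new roots to the previous columns and
        # keep the one with the smallest total displacement.
        best = min(itertools.permutations(range(3)),
                   key=lambda p: np.sum(np.abs(new[list(p)] - out[k - 1])))
        out[k] = new[list(best)]
    return out


t_vals = np.linspace(0.01, 0.9999, 1000)
tracked = tracked_roots(t_vals)  # shape (1000, 3); column j is one branch
```

Each column can then be plotted with `plt.plot(t_vals, tracked[:, j].real)`. The column order comes from continuity, so it need not match the r1/r2/r3 numbering of the symbolic solutions, and where two roots actually coincide, continuity alone cannot decide which branch is which. Brute-forcing the six permutations is cheap for a cubic; for higher-degree polynomials an assignment solver such as `scipy.optimize.linear_sum_assignment` would scale better.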

