Double Pendulum

The double pendulum is defined by its Lagrangian, the kinetic energy minus the potential energy:
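In terms of the angles θ₁, θ₂ and angular velocities ω₁, ω₂ (this is exactly what the code below computes):

$$
T = \tfrac{1}{2} m_1 l_1^2 \omega_1^2 + \tfrac{1}{2} m_2 \left( l_1^2 \omega_1^2 + l_2^2 \omega_2^2 + 2 l_1 l_2 \omega_1 \omega_2 \cos(\theta_1 - \theta_2) \right)
$$

$$
V = -(m_1 + m_2)\, g\, l_1 \cos\theta_1 - m_2\, g\, l_2 \cos\theta_2, \qquad \mathcal{L} = T - V
$$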
 
 
def lagrangian(q, q_dot, m1, m2, l1, l2, g):
    t1, t2 = q      # theta 1 and theta 2
    w1, w2 = q_dot  # omega 1 and omega 2
    # kinetic energy (T)
    T1 = 0.5 * m1 * (l1 * w1)**2
    T2 = 0.5 * m2 * ((l1 * w1)**2 + (l2 * w2)**2 +
                     2 * l1 * l2 * w1 * w2 * jnp.cos(t1 - t2))
    T = T1 + T2
    # potential energy (V)
    y1 = -l1 * jnp.cos(t1)
    y2 = y1 - l2 * jnp.cos(t2)
    V = m1 * g * y1 + m2 * g * y2
    return T - V
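All of the snippets in this post assume the standard imports from the original Lagrangian Neural Networks notebook; a minimal set that makes them runnable (module paths follow the older JAX releases the notebook used):

import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from jax.experimental.ode import odeint
from jax.experimental import stax  # in newer JAX: jax.example_libraries.stax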
Solving the Euler-Lagrange equations for this system gives the following set of equations, writing ωᵢ = θ̇ᵢ for the angular velocities.
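Reconstructed in the notation of f_analytical below, with α₁, α₂ as coupling coefficients:

$$
\alpha_1 = \frac{l_2}{l_1}\,\frac{m_2}{m_1 + m_2}\cos(\theta_1 - \theta_2), \qquad
\alpha_2 = \frac{l_1}{l_2}\cos(\theta_1 - \theta_2)
$$

$$
f_1 = -\frac{l_2}{l_1}\,\frac{m_2}{m_1 + m_2}\,\omega_2^2\sin(\theta_1 - \theta_2) - \frac{g}{l_1}\sin\theta_1, \qquad
f_2 = \frac{l_1}{l_2}\,\omega_1^2\sin(\theta_1 - \theta_2) - \frac{g}{l_2}\sin\theta_2
$$

$$
\ddot\theta_1 = \frac{f_1 - \alpha_1 f_2}{1 - \alpha_1\alpha_2}, \qquad
\ddot\theta_2 = \frac{f_2 - \alpha_2 f_1}{1 - \alpha_1\alpha_2}
$$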
Insight
As long as we can put the dynamics in this form, we don't need derivative data: we can obtain the trajectory by feeding the RHS to a differential-equation solver, and then use the result for further computation.
The following code returns the time derivative of the state.
def f_analytical(state, t=0, m1=1, m2=1, l1=1, l2=1, g=9.8):
    t1, t2, w1, w2 = state
    a1 = (l2 / l1) * (m2 / (m1 + m2)) * jnp.cos(t1 - t2)
    a2 = (l1 / l2) * jnp.cos(t1 - t2)
    f1 = -(l2 / l1) * (m2 / (m1 + m2)) * (w2**2) * jnp.sin(t1 - t2) - \
         (g / l1) * jnp.sin(t1)
    f2 = (l1 / l2) * (w1**2) * jnp.sin(t1 - t2) - (g / l2) * jnp.sin(t2)
    g1 = (f1 - a1 * f2) / (1 - a1 * a2)
    g2 = (f2 - a2 * f1) / (1 - a1 * a2)
    return jnp.stack([w1, w2, g1, g2])

Dynamics from the Time-Derivative Equation

@partial(jax.jit, backend='cpu')
def solve_analytical(initial_state, times):
    return odeint(f_analytical, initial_state, t=times, rtol=1e-10, atol=1e-10)
odeint
Gives you the solution of the differential equation at the requested times, for any initial value and any RHS function (here f_analytical). Each row of the solution is the quadruple (θ₁, θ₂, ω₁, ω₂).
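For example, a quick shape check (the initial state and times here are made up for the example):

x0 = np.array([np.pi/4, np.pi/2, 0.0, 0.0], dtype=np.float32)
states = solve_analytical(x0, np.linspace(0.0, 1.0, 101, dtype=np.float32))
print(states.shape)  # (101, 4): one (theta1, theta2, omega1, omega2) row per time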

Dynamics from the Lagrangian

In order to obtain the trajectory back from the learned Lagrangian Neural Network, the acceleration is expressed in terms of gradients of our LNN, and the resulting equation of motion is integrated with odeint.
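Concretely, equation_of_motion below implements the identity from the LNN paper:

$$
\ddot q = \big(\nabla_{\dot q}\nabla_{\dot q}^{\top}\mathcal{L}\big)^{-1}\big[\nabla_q \mathcal{L} - \big(\nabla_q \nabla_{\dot q}^{\top}\mathcal{L}\big)\,\dot q\big]
$$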
def equation_of_motion(lagrangian, state, t=None):
    q, q_t = jnp.split(state, 2)
    q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
            @ (jax.grad(lagrangian, 0)(q, q_t)
               - jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
    return jnp.concatenate([q_t, q_tt])

def solve_lagrangian(lagrangian, initial_state, **kwargs):
    @partial(jax.jit, backend='cpu')
    def f(initial_state):
        return odeint(partial(equation_of_motion, lagrangian),
                      initial_state, **kwargs)
    return f(initial_state)
 
 

Dynamics from the LNN Equation

@partial(jax.jit, backend='cpu')
def solve_autograd(initial_state, times, m1=1, m2=1, l1=1, l2=1, g=9.8):
    L = partial(lagrangian, m1=m1, m2=m2, l1=l1, l2=l2, g=g)
    return solve_lagrangian(L, initial_state, t=times, rtol=1e-10, atol=1e-10)
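A quick sanity check (the state and times here are assumptions): driving the solver with the true analytical Lagrangian should reproduce the hand-derived dynamics.

x0 = np.array([3*np.pi/7, 3*np.pi/4, 0.0, 0.0], dtype=np.float32)
times = np.linspace(0, 10, 101, dtype=np.float32)
max_err = jnp.max(jnp.abs(solve_analytical(x0, times) - solve_autograd(x0, times)))
print(max_err)  # should be tiny: both integrate the same physics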
 
Normalize
def normalize_dp(state):
    # wrap generalized coordinates to [-pi, pi]
    return jnp.concatenate([(state[:2] + np.pi) % (2 * np.pi) - np.pi,
                            state[2:]])
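For instance (values made up), angles outside [-π, π] get wrapped while the velocities pass through unchanged:

s = jnp.array([4.0, -4.0, 1.0, 2.0])
print(normalize_dp(s))  # approx [-2.283, 2.283, 1.0, 2.0]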
 
Runge-Kutta Method
def rk4_step(f, x, t, h):
    # one step of fourth-order Runge-Kutta integration
    k1 = h * f(x, t)
    k2 = h * f(x + k1/2, t + h/2)
    k3 = h * f(x + k2/2, t + h/2)
    k4 = h * f(x + k3, t + h)
    return x + 1/6 * (k1 + 2 * k2 + 2 * k3 + k4)
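For example, one hypothetical step of the analytical dynamics with step size h = 0.01:

state = jnp.array([3*jnp.pi/7, 3*jnp.pi/4, 0.0, 0.0])
next_state = rk4_step(f_analytical, state, t=0.0, h=0.01)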

Generate training data

 
time_step = 0.01
N = 1500
analytical_step = jax.jit(jax.vmap(partial(rk4_step, f_analytical,
                                           t=0.0, h=time_step)))
jax.vmap
Vectorizes a function: the returned function maps the original over the leading (batch) axis of its inputs.
partial from functools
Returns a function with some of its arguments already filled in: https://docs.python.org/3/library/functools.html
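A toy illustration of how the two combine (the function and values here are made up for the example):

def scale_add(x, a, b):
    return a * x + b

f = partial(scale_add, a=2.0, b=1.0)  # fix a and b; f now takes only x
print(jax.vmap(f)(jnp.arange(3.0)))   # [1. 3. 5.] -- f applied to each element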
Initial value and time
x0 = np.array([3*np.pi/7, 3*np.pi/4, 0, 0], dtype=np.float32)
t = np.arange(N, dtype=np.float32)  # time steps 0 to N-1
Data generation
x_train is generated by solve_analytical: (initial state, times) → trajectory obtained by integrating the dynamics equation (f_analytical).
xt_train is generated by applying f_analytical to each state in x_train.
x_train = jax.device_get(solve_analytical(x0, t))           # dynamics for the first N time steps
xt_train = jax.device_get(jax.vmap(f_analytical)(x_train))  # time derivatives of each state
y_train = jax.device_get(analytical_step(x_train))          # analytical next step
Question
Note that xt_train is generated by applying the RHS function to the solution that was itself produced by numerically integrating that same function.
noise = np.random.RandomState(0).randn(x0.size)           # fixed, seeded perturbation vector
t_test = np.arange(N, 2*N, dtype=np.float32)              # time steps N to 2N-1
x_test = jax.device_get(solve_analytical(x0, t_test))     # dynamics for the next N time steps
xt_test = jax.device_get(jax.vmap(f_analytical)(x_test))  # time derivatives of each state
y_test = jax.device_get(analytical_step(x_test))          # analytical next step
 

Network

Now replace the analytical Lagrangian with a parametric model. learned_lagrangian takes the learnable params and returns a Lagrangian function, defined inside it, that evaluates the neural network on the normalized state.
def learned_lagrangian(params):
    def lagrangian(q, q_t):
        assert q.shape == (2,)
        state = normalize_dp(jnp.concatenate([q, q_t]))
        return jnp.squeeze(nn_forward_fn(params, state), axis=-1)
    return lagrangian
 
init_random_params, nn_forward_fn = stax.serial(
    stax.Dense(128), stax.Softplus,
    stax.Dense(128), stax.Softplus,
    stax.Dense(1),
)
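Before training, the parameters must be initialized; a minimal sketch, assuming a 4-dimensional state input and a seed of 0:

rng = jax.random.PRNGKey(0)
output_shape, init_params = init_random_params(rng, (-1, 4))  # -1 leaves the batch dim unconstrained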

Loss function

def loss(params, batch, time_step=None):
    state, targets = batch
    if time_step is not None:
        f = partial(equation_of_motion, learned_lagrangian(params))
        preds = jax.vmap(partial(rk4_step, f, t=0.0, h=time_step))(state)
    else:
        preds = jax.vmap(partial(equation_of_motion, learned_lagrangian(params)))(state)
    return jnp.mean((preds - targets) ** 2)
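A minimal training-step sketch using this loss. The optimizer setup is an assumption: Adam with a fixed 1e-3 learning rate stands in for whatever schedule the original notebook uses, and (x_train, xt_train) from above serves as the batch.

from jax.experimental import optimizers  # in newer JAX: jax.example_libraries.optimizers

opt_init, opt_update, get_params = optimizers.adam(1e-3)
opt_state = opt_init(init_params)

@jax.jit
def train_step(i, opt_state, batch):
    params = get_params(opt_state)
    grads = jax.grad(loss)(params, batch, None)  # time_step=None: match derivatives directly
    return opt_update(i, grads, opt_state)

opt_state = train_step(0, opt_state, (x_train, xt_train))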
 