# A Velocity-Verlet integrator implemented in Python.
# Use this as a starting point for modified integrators.
#

from MMTK import *
from MMTK.Proteins import Protein
from MMTK.ForceFields import Amber99ForceField
from MMTK.Trajectory import Trajectory, TrajectoryOutput, SnapshotGenerator

# Velocity Verlet integrator in Python
def doVelocityVerletSteps(delta_t, nsteps,
                          equilibration_temperature = None,
                          equilibration_frequency = 1,
                          start_time = 0.):
    """Advance the module-level ``universe`` by ``nsteps`` Velocity-Verlet steps.

    Positions and velocities are read from the global ``universe`` and
    written back to it after every step.  A snapshot carrying 'time' and
    'potential_energy' data is emitted through the global ``snapshot``
    generator before the first step and after every step.  When calls are
    chained, the initial snapshot of a call repeats the state reached at
    the end of the previous call, apart from any velocity rescaling that
    followed that call's last snapshot.

    :param delta_t: integration time step (e.g. ``1.*Units.fs``)
    :param nsteps: number of integration steps to perform
    :param equilibration_temperature: if not None, the velocities are
        rescaled to this temperature every ``equilibration_frequency`` steps
        (starting with step 0), after that step's snapshot has been taken
    :param equilibration_frequency: rescaling interval in steps; must be a
        positive integer when ``equilibration_temperature`` is given
    :param start_time: time value recorded with the initial snapshot.  Pass
        the return value of a previous call to keep the time axis of the
        trajectory monotonic across consecutive calls.
    :returns: the simulation time after the last step
    :raises ValueError: if rescaling is requested and
        ``equilibration_frequency`` is smaller than 1
    """
    if equilibration_temperature is not None and equilibration_frequency < 1:
        # Without this check, step % 0 would raise ZeroDivisionError, and
        # negative values would silently give meaningless rescaling intervals.
        raise ValueError("equilibration_frequency must be a positive integer, "
                         "got %r" % (equilibration_frequency,))
    configuration = universe.configuration()
    velocities = universe.velocities()
    gradients = ParticleVector(universe)
    inv_masses = 1./universe.masses()
    evaluator = universe.energyEvaluator()
    energy, gradients = evaluator(gradients)
    # Half-step velocity increment from the current gradients: -dt/2 * grad/m.
    dv = -0.5*delta_t*gradients*inv_masses
    time = start_time
    snapshot(data={'time': time,
                   'potential_energy': energy})
    for step in range(nsteps):
        # First half kick, then a full position update with the new velocities.
        velocities += dv
        configuration += delta_t*velocities
        universe.setConfiguration(configuration)
        # Gradients at the new positions give the second half kick; the same
        # dv is reused as the first half kick of the next step.
        energy, gradients = evaluator(gradients)
        dv = -0.5*delta_t*gradients*inv_masses
        velocities += dv
        universe.setVelocities(velocities)
        time += delta_t
        snapshot(data={'time': time,
                       'potential_energy': energy})
        if equilibration_temperature is not None \
           and step % equilibration_frequency == 0:
            universe.scaleVelocitiesToTemperature(equilibration_temperature)
            # Rescaling acts on the universe's velocities.  Re-read them so
            # the next setVelocities() call cannot overwrite the rescaled
            # values with a stale local copy.
            velocities = universe.velocities()
    return time

# Define system
universe = InfiniteUniverse(Amber99ForceField())
universe.protein = Protein('bala1')

# Integration time step shared by all stages below
time_step = 1.*Units.fs

# Create trajectory and snapshot generator
trajectory = Trajectory(universe, "md_trajectory.nc", "w",
                        "Generated by a Python integrator")
try:
    snapshot = SnapshotGenerator(universe,
                                 actions = [TrajectoryOutput(trajectory,
                                                             ["all"], 0, None, 1)])

    # Initialize velocities
    universe.initializeVelocitiesToTemperature(50.*Units.K)
    # Heat in stages, rescaling the velocities at every step
    for temperature in [50., 100., 200., 300.]:
        doVelocityVerletSteps(delta_t = time_step, nsteps = 500,
                              equilibration_temperature = temperature*Units.K,
                              equilibration_frequency = 1)
    # Equilibrate at 300 K, rescaling only every 10 steps
    doVelocityVerletSteps(delta_t = time_step, nsteps = 500,
                          equilibration_temperature = 300.*Units.K,
                          equilibration_frequency = 10)
    # Production run (no velocity rescaling)
    doVelocityVerletSteps(delta_t = time_step, nsteps = 5000)
finally:
    # Close the trajectory even if a stage fails, so the snapshots already
    # written are not left in an unclosed output file.
    trajectory.close()


