from pylab import *

import readsnap
import time

# Column indices into the per-particle position arrays: the two coordinates
# that are projected onto the plot's horizontal and vertical axes.
x = 0  # axis1
y = 1  # axis2

# Snapshot numbers to animate, inclusive on both ends.
start = 0
end = 20

# Particles with ID <= this value are assigned to the first galaxy, all
# others to the second.
GALAXY_ID_SPLIT = 13000

# Path template for the snapshot files; filled with the snapshot number,
# zero-padded to three digits.
SNAPSHOT_TEMPLATE = "../Gadget-2.0.7/Output/Gal11/snapshot_%03d"

# Half-width of the square plot window (axes are labelled in kpc).
PLOT_HALF_WIDTH = 130.


def snapshot_path(number):
    """Return the file name of snapshot *number* (zero-padded to 3 digits)."""
    return SNAPSHOT_TEMPLATE % number


def split_by_galaxy(ids, threshold=GALAXY_ID_SPLIT):
    """Split an array of particle IDs into the two galaxies.

    Parameters
    ----------
    ids : numpy array of particle IDs.
    threshold : IDs <= threshold belong to the first galaxy, the rest to
        the second.

    Returns
    -------
    (first, second) : integer index arrays into *ids*.
    """
    first = (ids <= threshold).nonzero()[0]
    second = (ids > threshold).nonzero()[0]
    return first, second


def mean_position(pos, ix=x, iy=y):
    """Return the unweighted mean ``(cx, cy)`` of the rows of *pos*.

    *pos* holds one row per particle; columns *ix* and *iy* are averaged.
    Every particle counts equally, so this is the centre of mass only if
    all particles have the same mass (NOTE(review): masses are not read
    here -- confirm that holds for this simulation).

    Returns ``(nan, nan)`` for an empty selection instead of dividing by
    zero, so the corresponding marker is simply not drawn.
    """
    if len(pos) == 0:
        return float("nan"), float("nan")
    return float(pos[:, ix].mean()), float(pos[:, iy].mean())


def main(first=start, last=end, delay=0.5):
    """Animate the x-y projection of the two galaxies, one frame per snapshot.

    Each frame shows the disk and bulge particles of both galaxies and marks
    the mean position of each galaxy's disk + bulge particles with an ``x``.

    Parameters
    ----------
    first, last : inclusive range of snapshot numbers to show.
    delay : seconds each frame stays on screen.
    """
    fig = figure()
    ax = fig.add_subplot(111, aspect="equal")
    ax.set_xlabel("x / kpc")
    ax.set_ylabel("y / kpc")
    # Explicit limits switch autoscaling off, so the set_data calls below
    # never move the view; no need to reset the limits every frame.
    ax.set_xlim(-PLOT_HALF_WIDTH, PLOT_HALF_WIDTH)
    ax.set_ylim(-PLOT_HALF_WIDTH, PLOT_HALF_WIDTH)

    def make_line(marker, color):
        """Create an initially single-point, marker-only line artist."""
        line, = ax.plot(0, 0, marker=marker, linestyle="none", c=color)
        return line

    disk1 = make_line(",", "blue")
    disk2 = make_line(",", "red")
    bulge1 = make_line(",", "yellow")
    bulge2 = make_line(",", "green")
    centre1 = make_line("x", "black")
    centre2 = make_line("x", "black")

    fig.show()

    for number in range(first, last + 1):
        filename = snapshot_path(number)

        # Particle type 2 holds the disk particles, type 3 the bulge particles.
        posdisk = readsnap.read_block(filename, "POS ", parttype=2)
        posbulge = readsnap.read_block(filename, "POS ", parttype=3)
        iddisk = readsnap.read_block(filename, "ID  ", parttype=2)
        idbulge = readsnap.read_block(filename, "ID  ", parttype=3)

        disk_g1, disk_g2 = split_by_galaxy(iddisk)
        bulge_g1, bulge_g2 = split_by_galaxy(idbulge)

        disk1.set_data(posdisk[disk_g1, x], posdisk[disk_g1, y])
        disk2.set_data(posdisk[disk_g2, x], posdisk[disk_g2, y])
        bulge1.set_data(posbulge[bulge_g1, x], posbulge[bulge_g1, y])
        bulge2.set_data(posbulge[bulge_g2, x], posbulge[bulge_g2, y])

        # Galaxy centres: mean position of each galaxy's disk + bulge
        # particles. set_data needs sequences -- passing scalars to
        # set_xdata/set_ydata is deprecated since matplotlib 3.7 and
        # rejected by later versions.
        cx1, cy1 = mean_position(vstack((posdisk[disk_g1], posbulge[bulge_g1])))
        cx2, cy2 = mean_position(vstack((posdisk[disk_g2], posbulge[bulge_g2])))
        centre1.set_data([cx1], [cy1])
        centre2.set_data([cx2], [cy2])

        # pause() redraws the figure and keeps the GUI event loop running
        # while waiting; time.sleep() would leave the window frozen.
        fig.canvas.draw_idle()
        pause(delay)


if __name__ == "__main__":
    main()

