# -*- coding: utf-8 -*-
"""
Created on Tuesday November 6, 2018

@author: pcutter

Modified by: (Your name(s) here)
With Assistance from:  (Anyone who helped)

Ants.py
This module provides a computational model for examining a specific system: a
discrete-time, discrete-space model of ant movement and population dynamics.

This model will make the following simplifying assumptions.
1. The environment is divided into a finite grid of discrete cells, each of
which can contain at most one individual.
2. Individuals can move from one cell to another.
3. Individuals can only interact with their immediate neighbors.

"""

import numpy as np
import matplotlib.pyplot as plt
import random
import matplotlib as mpl

# Peak pheromone level of the initial trail laid down by initPherGrid.
MAXPHER = 50.0
# Ant-grid cell states. EMPTY means no ant is in the cell.
EMPTY = 0
# Heading states. The value minus one is the position of that direction in the
# N, E, S, W neighbor lists built by get4Neighbors.
NORTH = 1
EAST = 2
SOUTH = 3
WEST = 4
# An ant that is present but not moving.
STAY = 5
# Marks the one-cell frame around the ant grid.
BORDER = 6
# Amount removed from an empty cell's pheromone on each walk step (floored at 0).
EVAPORATE = 1
# A departing ant only deposits pheromone if the cell's level exceeds this.
THRESHOLD = 0
# Amount of pheromone a departing ant adds to the cell it leaves.
DEPOSIT = 2

def display_env(env):
    """
    Draw the environment as an image and pause briefly so it can be seen.

    Each cell value is turned into a color by whatever colormap matplotlib
    is currently using for imshow; no explicit colormap is chosen here.
    PARAM:
        env: the current environment (2D array)
    """
    # Call plt.figure() first if every grid should get its own window.
    plt.imshow(env)
    plt.show()
    # Leave the image up for half a second before the caller continues.
    plt.pause(0.5)

def display_env2(env):
    """
    Draw the ant grid in a new figure with one fixed color per cell state.

    Colors, in state order 0..6: EMPTY blue, NORTH red, EAST green,
    SOUTH orange, WEST pink, STAY purple, BORDER yellow.
    PARAM:
        env: the ant grid (2D array of integer cell states 0..6)
    """
    plt.figure()
    colors = ['blue', 'red', 'green', 'orange', 'pink', 'purple', 'yellow']
    cmap = mpl.colors.ListedColormap(colors)
    # One unit-wide bin centred on each integer state: -0.5, 0.5, ..., 6.5.
    # BoundaryNorm needs the number of bins to match the number of colors,
    # and every state (including BORDER) has to fall inside the bounds.
    bounds = [state - 0.5 for state in range(len(colors) + 1)]
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    img = plt.imshow(env, interpolation='nearest', cmap=cmap, norm=norm)
    # The colorbar takes its colormap and norm from the image itself.
    plt.colorbar(img, boundaries=bounds, ticks=list(range(len(colors))))
    plt.show()


def count_population(env, pop_type):
    """
    Count the cells of the environment whose value equals pop_type.

    For example, count_population(antGrid, STAY) gives the number of
    cells holding a stationary ant.
    PARAMS:
        env: the two dimensional array.
        pop_type: an integer representing the cell value to count
    RETURNS:
        the number of cells in env equal to pop_type
    """
    # The column count is taken from the first row, and only looked up when
    # there is at least one row to scan.
    return sum(
        1
        for row in range(len(env))
        for col in range(len(env[0]))
        if env[row][col] == pop_type
    )

def initAntGrid(n, probAnt):
    """
    Build the initial ant grid and print it.

    The grid has n x n interior cells surrounded by a one-cell frame of
    BORDER cells. Each interior cell independently receives an ant with
    probability probAnt; a new ant gets a random heading (NORTH, EAST,
    SOUTH or WEST). All other interior cells are EMPTY.
    PARAMS:
        n: number of interior rows and columns.
        probAnt: probability (0 to 1) that an interior cell starts with an ant.
    RETURNS:
        grid - the initialized (n+2) x (n+2) integer array.
    RAISES:
        ValueError: if probAnt is greater than 1 or less than 0.
    """
    # Reject impossible probabilities before doing any work.
    if probAnt > 1:
        raise ValueError('Probability of an ant in any cell cannot be greater than 100%!')
    if probAnt < 0:
        raise ValueError('Probability of an ant in any cell cannot be negative!')

    grid = np.zeros((n + 2, n + 2), dtype=int)

    # Frame the grid with BORDER cells.
    grid[0, :] = BORDER
    grid[n + 1, :] = BORDER
    grid[:, 0] = BORDER
    grid[:, n + 1] = BORDER

    # Populate interior cells; one random draw decides whether a cell gets an
    # ant, and a second draw (only for occupied cells) picks its heading.
    for r in range(1, n + 1):
        for c in range(1, n + 1):
            if random.random() < probAnt:
                grid[r, c] = random.randint(NORTH, WEST)
    print (grid)
    return grid

def initPherGrid(n):
    """
    Build the initial pheromone grid and print it.

    The grid has n x n interior cells, all 0 except for one horizontal
    trail whose level rises linearly from MAXPHER/n in the first interior
    column to MAXPHER in the last. The surrounding frame is set to -0.01.
    PARAMS:
        n: number of interior rows and columns.
    RETURNS:
        grid - the initialized (n+2) x (n+2) float array.
    """
    grid = np.zeros((n + 2, n + 2))

    # Frame cells carry a small negative level.
    grid[0, :] = -0.01
    grid[n + 1, :] = -0.01
    grid[:, 0] = -0.01
    grid[:, n + 1] = -0.01

    # The trail lies in row n // 2. For n == 1 that would be row 0, which is
    # the frame, so clamp to the first interior row to keep the border intact.
    trail_row = max(1, n // 2)
    for i in range(1, n + 1):
        grid[trail_row, i] = MAXPHER * i / n
    print (grid)
    return grid

def getNeighborPhers(r, c, pherGrid):
    """
    Return the pheromone values of the 8 cells surrounding cell (r, c).

    The list is ordered column by column for the rows above and below
    (above-left, below-left, above, below, above-right, below-right),
    followed by the left and right neighbors in row r.
    """
    above = pherGrid[r - 1]
    below = pherGrid[r + 1]
    values = []
    for col in (c - 1, c, c + 1):
        values += [above[col], below[col]]
    same_row = pherGrid[r]
    values += [same_row[c - 1], same_row[c + 1]]
    return values

def get4Neighbors(r, c, grid):
    """
    Return the values of the four cells adjacent to (r, c).

    The result is ordered north, east, south, west.
    """
    north = grid[r - 1][c]
    east = grid[r][c + 1]
    south = grid[r + 1][c]
    west = grid[r][c - 1]
    return [north, east, south, west]
    
def diffusion(diffRate, sitePher, neighborPhers):
    """
    Return a cell's pheromone level after one diffusion step.

    The cell keeps the fraction (1 - 8*diffRate) of its own level and
    gains diffRate times the sum of its neighbors' levels. The factor 8
    is fixed, so neighborPhers is expected to hold the 8 surrounding
    values.
    PARAMS:
        diffRate: diffusion rate applied per neighbor.
        sitePher: current pheromone level of the cell.
        neighborPhers: pheromone levels of the neighboring cells.
    RETURNS:
        the cell's new pheromone level.
    """
    retained = (1 - 8 * diffRate) * sitePher
    received = diffRate * sum(neighborPhers)
    return retained + received
    
def applyDiffusionExtended(antGrid, pherGrid, diffRate):
    """
    Apply one diffusion step to every interior cell of the pheromone grid.

    Each new value is computed from the unmodified input grid, so the
    order in which cells are visited does not affect the result. Frame
    cells are copied through unchanged.
    PARAMS:
        antGrid: the ant grid; only its dimensions are used here.
        pherGrid: the current pheromone grid (not modified).
        diffRate: diffusion rate passed to diffusion().
    RETURNS:
        a new pheromone grid holding the diffused levels.
    """
    diffused = pherGrid.copy()
    for row in range(1, len(antGrid) - 1):
        for col in range(1, len(antGrid[0]) - 1):
            surrounding = getNeighborPhers(row, col, pherGrid)
            diffused[row, col] = diffusion(diffRate, pherGrid[row, col], surrounding)
    return diffused

def sense(site, neighAntList, neighPherList):
    """
    Choose the heading for the ant occupying a cell.

    The ant turns toward the available neighbor with the most pheromone,
    picking at random among ties. A neighbor is unavailable if it lies in
    the direction stored in the ant's own state or if it already holds an
    ant (moving or stationary). If the best remaining level is negative
    (for example the -0.01 used on the pheromone grid's frame) the ant
    stays put.
    PARAMS:
        site: state of the ant's cell (NORTH, EAST, SOUTH, WEST or STAY).
        neighAntList: ant-grid states of the N, E, S, W neighbors.
        neighPherList: pheromone levels of the N, E, S, W neighbors
            (not modified).
    RETURNS:
        NORTH, EAST, SOUTH or WEST for the chosen neighbor, or STAY.
    """
    if site == STAY:
        # NOTE(review): a STAY ant never senses again, and walk() has no branch
        # that changes a STAY cell, so a stopped ant remains stopped -- confirm
        # this is the intended model.
        return STAY

    # Work on a copy so the caller's list is left untouched.
    levels = list(neighPherList)

    # Rule out the direction recorded in the ant's own state.
    levels[int(site) - 1] = -2

    # Rule out every neighbor that already holds an ant. STAY ants count too:
    # walk() refuses to move an ant into any non-empty cell.
    for i, occupant in enumerate(neighAntList):
        if NORTH <= occupant <= STAY:
            levels[i] = -2
    print (levels)

    best = max(levels)
    if best < 0:
        return STAY

    # Pick at random among the neighbors tied for the highest level; list
    # positions are 0-based while headings start at 1.
    candidates = [i for i, level in enumerate(levels) if level == best]
    heading = random.choice(candidates) + 1
    print (heading)
    return heading
def applySenseExtended(antGrid, pherGrid):
    """
    Let every ant in the interior of the grid choose its next heading.

    All decisions are based on the input grids, so ants sense
    simultaneously; no ant sees another ant's new heading. Empty cells
    and frame cells are copied through unchanged.
    PARAMS:
        antGrid: the current ant grid (not modified).
        pherGrid: the current pheromone grid (not modified).
    RETURNS:
        a new ant grid in which each occupied interior cell holds the
        value returned by sense().
    """
    newAntGrid = antGrid.copy()
    for r in range(1, len(antGrid) - 1):
        for c in range(1, len(antGrid[0]) - 1):
            if antGrid[r, c] == EMPTY:
                continue
            # Only occupied cells need their neighborhoods looked up.
            antList = get4Neighbors(r, c, antGrid)
            pherList = get4Neighbors(r, c, pherGrid)
            newAntGrid[r, c] = sense(antGrid[r, c], antList, pherList)
    return newAntGrid

def walk(antGrid, pherGrid):
    """
    Move every ant one cell in its heading and update the pheromone grid.

    Interior cells are visited row by row:
      * a cell that was EMPTY in the input grid loses EVAPORATE pheromone,
        never dropping below 0;
      * an ant heading NORTH, EAST, SOUTH or WEST moves into the adjacent
        cell if that cell is EMPTY in the grid being built, so the first
        ant to claim a cell wins. The cell it leaves gains DEPOSIT
        pheromone when its level exceeds THRESHOLD, and the ant is stored
        in its new cell with the opposite heading (pointing back at the
        cell it came from);
      * an ant whose destination is not EMPTY becomes STAY;
      * cells in any other state are left as they are.
    PARAMS:
        antGrid: the current ant grid (not modified).
        pherGrid: the current pheromone grid (not modified).
    RETURNS:
        (newAntGrid, newPherGrid) - the updated copies of both grids.
    """
    # heading -> (row step, column step, state written into the destination)
    moves = {
        NORTH: (-1, 0, SOUTH),
        EAST: (0, 1, WEST),
        SOUTH: (1, 0, NORTH),
        WEST: (0, -1, EAST),
    }
    newAntGrid = antGrid.copy()
    newPherGrid = pherGrid.copy()
    for r in range(1, len(antGrid) - 1):
        for c in range(1, len(antGrid[0]) - 1):
            state = antGrid[r, c]
            if state == EMPTY:
                newPherGrid[r, c] = max(0, newPherGrid[r, c] - EVAPORATE)
                continue
            if state not in moves:
                # Not a moving ant; nothing to do for this cell.
                continue

            d_row, d_col, arrival_state = moves[state]
            target = (r + d_row, c + d_col)
            if newAntGrid[target] != EMPTY:
                # Destination already taken (or is the frame): stop here.
                newAntGrid[r, c] = STAY
                continue

            if newPherGrid[r, c] > THRESHOLD:
                newPherGrid[r, c] = newPherGrid[r, c] + DEPOSIT
            newAntGrid[r, c] = EMPTY
            newAntGrid[target] = arrival_state
    return newAntGrid, newPherGrid
            
def ants(n, probAnt, diffusionRate, t):
    """
    Run the ant simulation for t steps and display every ant grid.

    Each step consists of sensing (ants pick a heading), walking (ants
    move; pheromone is deposited and evaporates) and diffusion of the
    pheromone. After the last step, every recorded ant grid - starting
    with the initial one - is printed and shown with display_env().
    PARAMS:
        n: number of interior rows and columns of the grids.
        probAnt: probability that an interior cell starts with an ant.
        diffusionRate: pheromone diffusion rate used each step.
        t: number of simulation steps.
    """
    # Starting state; the ant grid is created before the pheromone grid.
    antGrid = initAntGrid(n, probAnt)
    pherGrid = initPherGrid(n)
    antGridList = [antGrid]
    # The pheromone history is recorded but not displayed.
    pherGridList = [pherGrid]

    for _ in range(t):
        antGrid = applySenseExtended(antGrid, pherGrid)
        antGrid, pherGrid = walk(antGrid, pherGrid)
        pherGrid = applyDiffusionExtended(antGrid, pherGrid, diffusionRate)
        antGridList.append(antGrid)
        pherGridList.append(pherGrid)

    for snapshot in antGridList:
        print (snapshot)
        display_env(snapshot)
        
