# -*- coding: utf-8 -*-
"""
Created on Monday, October 24, 2016

@author: Pam Cutter

This program contains functions that implement the K-Means
Clustering algorithm on 2-dimensional grade data.
"""

import math
import random
import numpy as np
import matplotlib.pyplot as plt
 
# practice using where
def testWhereFcn():
    """Demonstrate ``np.where`` by zeroing out the positive entries of an array.

    Builds a small array of positive and negative integers, uses ``np.where``
    to find the indices of the positive entries, then assigns 0 to exactly
    those positions. The array, the index tuple, and the modified array are
    printed along the way.

    RETURNS:
        the modified array (every positive entry replaced by 0)
    """
    # An array of 10 positive and negative integers.
    a = np.array([3, -1, 4, -1, -5, 9, 2, -6, 5, -3])
    print(a)

    # np.where with a single condition returns a tuple of index arrays,
    # one per dimension of `a`.
    b = np.where(a > 0)
    print(b)

    # Indexing with that tuple selects exactly the positive entries.
    a[b] = 0

    print(a)
    return a

def distance(pt1, pt2):
    """Return the Euclidean distance between two 2-D points.

    PARAMS:
        pt1: a tuple/list/array of size 1x2
        pt2: a tuple/list/array of size 1x2
    RETURNS:
        sqrt((x1 - x2)**2 + (y1 - y2)**2) as a float
    Only the first two coordinates of each point are used.
    """
    # math.hypot computes sqrt(dx*dx + dy*dy) without intermediate overflow.
    return math.hypot(pt1[0] - pt2[0], pt1[1] - pt2[1])
    
def initClusterCenters(data_array, k):
    """Choose k distinct random points from the data as the initial cluster centers.

    PARAMS:
        data_array: the two-dimensional array of data; only the first two
            columns are used as point coordinates
        k: the number of clusters
    RETURNS:
        a (k, 2) float array of distinct points drawn from data_array
    RAISES:
        ValueError: if data_array has fewer than k distinct points (the
            rejection loop below could otherwise never terminate)
    """
    points = np.asarray(data_array, dtype=float)[:, :2]
    if k > len(np.unique(points, axis=0)):
        raise ValueError(
            "cannot choose %d distinct centers from data with fewer distinct points" % k
        )

    centers = np.zeros((k, 2))   # create k cluster centers
    for i in range(k):
        # choose a random index in the range of possible indices
        ind = random.randint(0, len(points) - 1)
        # Reject the point if it equals an already-chosen center. Compare whole
        # rows, and only against the i centers filled so far. (`pt in centers`
        # on an ndarray is true if ANY single coordinate matches, including
        # the placeholder zeros of unfilled rows.)
        while np.any(np.all(centers[:i] == points[ind], axis=1)):
            ind = random.randint(0, len(points) - 1)
        centers[i] = points[ind]   # set the point in the centers array
    return centers
    
def labelForPt(centers, pt):
    """Determine which cluster an individual pt should be in.

    PARAMS:
        centers: the list of cluster centers; must be non-empty
            (centers[0] is read as the starting minimum)
        pt: the point to determine the cluster for
    RETURNS:
        the index of the center closest to pt; on a tie, the lowest index wins
        because only a strictly smaller distance replaces the current minimum
    """
    minD = distance(centers[0], pt)
    ind = 0
    for i, c in enumerate(centers):
        d = distance(c, pt)
        if d < minD:
            minD = d
            ind = i
    return ind
          
def determineLabels(data_array, centers):
    """Determine which cluster each point in the data array belongs to.

    PARAMS:
        data_array: the two-dimensional array of data; columns 0 and 1 are
            used as the point's coordinates
        centers: the array of points representing the cluster centers
    RETURNS:
        an integer array with one entry per row of data_array, holding the
        index of the nearest center for that row
    """
    return np.array(
        [labelForPt(centers, (row[0], row[1])) for row in data_array],
        dtype=int,
    )
    
def extractCluster(dataArray, labels, clusterNum):
    """Return the rows of dataArray that were assigned to cluster clusterNum.

    PARAMS:
        dataArray: the two-dimensional array of data
        labels: array of cluster labels, one per row of dataArray
            (as produced by determineLabels)
        clusterNum: the cluster whose member points should be returned
    RETURNS:
        the sub-array of dataArray holding only that cluster's points;
        it has zero rows if no point carries that label
    """
    # np.where yields a tuple of index arrays; indexing dataArray with it
    # pulls out the matching rows.
    return dataArray[np.where(labels == clusterNum)]

def recalculateCenters(centers, clusterList):
    """Recalculate the cluster centers from the mean values of each cluster's points.

    PARAMS:
        centers: the current cluster centers; modified in place
        clusterList: the list of clusters (sets of pts belonging to each
            cluster); clusterList[i] holds the points of cluster i, with x in
            column 0 and y in column 1
    RETURNS:
        centers, where centers[i] is now (mean of first column, mean of second
        column) of clusterList[i]. A cluster with no points keeps its previous
        center, because the mean of an empty set is undefined (NaN).
    """
    for i, cluster in enumerate(clusterList):
        pts = np.asarray(cluster, dtype=float)
        if len(pts) == 0:
            continue  # empty cluster: leave its center where it was
        centers[i] = (pts[:, 0].mean(), pts[:, 1].mean())
    return centers

def main():
    """Run the K-Means algorithm interactively on the class grade data.

    Loads 'class-grades.csv' and uses grade columns 1 and 5 (hw and final
    exam grades) as 2-D points. Asks the user for the number of clusters and
    the number of iterations. Each iteration labels the points, plots the
    clusters in a new figure (current centers drawn as black triangles), and
    then recomputes the centers.
    """
    class_data = np.loadtxt(r'class-grades.csv', delimiter=',')  # or replace with your filename
    grade_pairs = class_data[:, 1:6:4]  # slice to use grade columns 1 and 5 (hw and final exam grades)

    # Ask user how many clusters
    k = int(input("How many clusters would you like to see? "))
    centers = initClusterCenters(grade_pairs, k)
    print("initial centers:", centers)

    # Ask user how many times to iterate
    numIterations = int(input("How many iterations? "))
    colorList = ['red', 'blue', 'green', 'yellow', 'magenta', 'pink']
    for _ in range(numIterations):
        # Determine which cluster the points belong to
        labels = determineLabels(grade_pairs, centers)

        plt.figure()
        plt.axis([20, 110, 20, 120])
        clusterList = []
        for i in range(k):
            # Get the actual set of points in cluster i
            c = extractCluster(grade_pairs, labels, i)
            clusterList.append(c)
            # Cycle through the colors so that k > len(colorList) does not
            # raise IndexError.
            plt.scatter(c[:, 0], c[:, 1], color=colorList[i % len(colorList)])
        # Plot all centers once, after (on top of) every cluster.
        plt.plot(centers[:, 0], centers[:, 1], 'k^')

        # recalculate the centers
        centers = recalculateCenters(centers, clusterList)
        print("new centers:", centers)


