import math
import csv
import matplotlib.pyplot as plt
import numpy as np

def distance_euclidienne(xa,ya,xb,yb):
    """Return the Euclidean distance between points (xa, ya) and (xb, yb)."""
    dx = xa - xb
    dy = ya - yb
    return math.sqrt(dx ** 2 + dy ** 2)

def kNN(liste, x, y, k, *, verbose=False):
    """Classify the point (x, y) by majority vote among its k nearest neighbours.

    Parameters
    ----------
    liste : sequence of rows
        Labelled points; for each row, items 0 and 1 are the coordinates and
        item 2 is the label.  Any further items are ignored.
    x, y : float
        Coordinates of the point to classify.
    k : int
        Number of nearest points taking part in the vote.  If ``k`` exceeds
        ``len(liste)``, every point votes.
    verbose : bool, optional (keyword-only)
        When true, print the sorted ``(distance, label)`` pairs (the former
        unconditional debug output).

    Returns
    -------
    The label occurring most often among the k nearest points.  Ties are
    broken in favour of the tied label whose closest representative is
    nearest to (x, y).

    Raises
    ------
    ValueError
        If ``k < 1`` or ``liste`` is empty, i.e. there is nobody to vote.
    """
    from collections import Counter

    if k < 1 or len(liste) == 0:
        raise ValueError("kNN needs k >= 1 and a non-empty list of points")

    # Sorting (distance, label) tuples orders by distance; equal distances fall
    # back to comparing labels, which keeps the neighbour order deterministic.
    distances_tri = sorted(
        (distance_euclidienne(x, y, row[0], row[1]), row[2]) for row in liste
    )
    if verbose:
        print(distances_tri)

    # Slicing (rather than indexing up to k) naturally caps k at len(liste).
    # Counter records labels in first-seen order, i.e. nearest first, and max()
    # returns the first maximal key, so ties go to the label seen nearest.
    # Any label is counted, not only the three iris species.
    votes = Counter(label for _, label in distances_tri[:k])
    return max(votes, key=votes.get)

# Load the labelled points from iris.csv.  Each entry of lists_from_csv is
# [float(column 1), float(column 3), column 4]; presumably two numeric
# measurements and the species name -- TODO confirm against the CSV header.
# ``with`` guarantees the file is closed; newline="" is what the csv module
# expects for files it reads.
with open("iris.csv", "r", newline="", encoding="utf-8") as file:
    csv_reader = csv.reader(file)
    next(csv_reader)  # skip the header row

    lists_from_csv = [
        [float(row[1]), float(row[3]), row[4]]
        for row in csv_reader
        # csv.reader yields [] for blank lines (e.g. a trailing empty line);
        # indexing such a row would raise IndexError.
        if row
    ]

print(lists_from_csv)
# Column-wise views used for plotting: x coordinates, y coordinates, labels.
x, y, var = zip(*lists_from_csv)

def conditions(variete):
    """Return the matplotlib colour code for one iris species.

    'Iris-setosa' -> 'r', 'Iris-virginica' -> 'b', 'Iris-versicolor' -> 'g';
    any other value yields None.
    """
    palette = (
        ('Iris-setosa', 'r'),
        ('Iris-virginica', 'b'),
        ('Iris-versicolor', 'g'),
    )
    for nom, couleur in palette:
        if variete == nom:
            return couleur
    return None

# Classify the query point (3.05, 1.6) using its single nearest neighbour (k=1).
print('classe majoritaire: ', kNN(lists_from_csv, 3.05, 1.6, k=1))

def afficher_inconnu(xi,yi):
    """Plot the loaded points coloured by species, plus the point (xi, yi) in black.

    Reads the module-level x, y and var sequences and maps each label to a
    colour with conditions(); the figure is then displayed via plt.show().
    """
    couleurs = list(map(conditions, var))
    plt.scatter(x, y, color=couleurs)   # known points, one colour per label
    plt.scatter(xi, yi, color='black')  # the point being classified
    plt.show()
    
# Draw the dataset together with the query point (3.05, 1.6).
afficher_inconnu(xi=3.05, yi=1.6)




