from mpi4py import MPI
import domain
import numpy as np

N = 0
E = 1
S = 2
W = 3


class Parallel:
    # rang du processus
    rang = 0
    # Nombre de processus
    nbprocs = 1
    # Topologie cartesiennne
    comm2d = MPI.COMM_WORLD
    # Nombre de processus dans chaque dimension
    dims = [0, 0]
    # rang des 4 voisins
    voisin = [0]*4
    # Type pour une ligne
    type_ligne = MPI.DOUBLE
    # Type pour une colonne
    type_colonne = MPI.DOUBLE

    def env_init(self):
        # TODO: Recuperation du rang et du nombre de processus -> rang, nbprocs
    
  
    def topology_init(self, d):
        # Copie d dans self
        self.d = d
        # Lecture ntx et nty depuis le fichier poisson.data
        with open("poisson.data", "r") as file:
            d.ntx = int(file.readline())
            d.nty = int(file.readline())
        # print(f"rang {rang}: ntx={ntx}, nty={nty}")
        # TODO: Nombre de processus dans chaque dimensions -> self.dims
    
        # TODO: Creation de la topology 2D -> self.comm2d
    
        if self.rang == 0:
            print("Lancement de poisson avec", self.nbprocs, "processus MPI")
            print("Taille du domaine: ntx="+str(d.ntx)+" nty="+str(d.nty))
            print("Dimension de la topologie:", self.dims[0],
                  "suivant x,", self.dims[1], "suivant y")
            print("-----------------------------------------")

    def domain_boundaries(self):
        # TODO: Coordonnees dans la topologie -> coords
    
        # Calcule les limites local en X
        self.d.sx = (coords[0]*self.d.ntx)//self.dims[0]+1
        self.d.ex = ((coords[0]+1)*self.d.ntx)//self.dims[0]
        # Calcule les limites local en Y
        self.d.sy = (coords[1]*self.d.nty)//self.dims[1]+1
        self.d.ey = ((coords[1]+1)*self.d.nty)//self.dims[1]
        print(f"rang dans la topologie: {self.rang} Indice des tableaux: "
              f"{self.d.sx} a {self.d.ex} suivant x, "
              f"{self.d.sy} a {self.d.ey} suivant y")

    def domain_neighbours(self):
        # TODO: Calcul des voisins Nord et Sud -> self.voisin[N], self.voisin[S]
    
        # TODO: Calcul des voisins Ouest et Est -> self.voisin[W], self.voisin[E]
    
        print(f"Processus {self.rang} a pour voisin: "
              f"N {self.voisin[N]} E {self.voisin[E]} "
              f"S {self.voisin[S]} W {self.voisin[W]}")

    def derived_datatypes(self):
        # TODO: Creation du type ligne -> self.type_ligne
    
        # TODO: Creation du type colonne -> self.type_colonne
    
  
    def communication(self):
        # sx, sy, ex, ey = local indice in self.d.u
        sx = 1
        sy = 1
        ex = self.d.ex - self.d.sx + 1
        ey = self.d.ey - self.d.sy + 1
        # Echange des points aux interfaces
        # TODO: Envoi au voisin N et reception du voisin S
    
        # TODO: Envoi au voisin S et reception du voisin N
    
        # TODO: Envoi au voisin W et reception du voisin E
    
        # TODO: Envoi au voisin E et  reception du voisin W
    
  
    def global_error(self):
        # Calcul de l'erreur globale (maximum des erreurs locales)
        local_erreur = 0.
        for iterx in range(self.d.sx, self.d.ex + 1):
            for itery in range(self.d.sy, self.d.ey + 1):
                dx = iterx - self.d.sx + 1
                dy = itery - self.d.sy + 1
                delta = np.abs(self.d.u[dx, dy] - self.d.u_new[dx, dy])
                local_erreur = max(local_erreur, delta)
        # TODO: Calcul de l'erreur sur tous les sous-domaines -> erreur

        return erreur

    def write_data(self):
        # TODO: Ouverture du fichier "data.dat" en écriture
    
        # TODO: Creation du type type_sous_tab_vue pour la vue sur le fichier
    
        # TODO: Définition de la vue sur le fichier a partir du debut
    
        # TODO: Creation du type derive type_sous_tab correspondant a la matrice u
        # sans les cellules fantomes
    
        # TODO: Ecriture du tableau u par tous les processus avec la vue
    
        # TODO: Fermeture du fichier
    
