from mpi4py import MPI
import domain
import numpy as np

N = 0
E = 1
S = 2
W = 3


class Parallel:
    # rang du processus
    rang = 0
    # Nombre de processus
    nbprocs = 1
    # Topologie cartesiennne
    comm2d = MPI.COMM_WORLD
    # Nombre de processus dans chaque dimension
    dims = [0, 0]
    # rang des 4 voisins
    voisin = [0]*4
    # Type pour une ligne
    type_ligne = MPI.DOUBLE
    # Type pour une colonne
    type_colonne = MPI.DOUBLE

    def env_init(self):
        # Recuperation du rang et du nombre de processus -> rang, nbprocs
        self.rang = self.comm2d.Get_rank()
        self.nbprocs = self.comm2d.Get_size()

    def topology_init(self, d):
        # Copie d dans self
        self.d = d
        # Lecture ntx et nty depuis le fichier poisson.data
        with open("poisson.data", "r") as file:
            d.ntx = int(file.readline())
            d.nty = int(file.readline())
        # print(f"rang {rang}: ntx={ntx}, nty={nty}")
        # Nombre de processus dans chaque dimensions -> self.dims
        self.dims = MPI.Compute_dims(self.nbprocs, self.dims)
        # Creation de la topology 2D -> self.comm2d
        self.comm2d = MPI.COMM_WORLD.Create_cart(self.dims)
        if self.rang == 0:
            print("Lancement de poisson avec", self.nbprocs, "processus MPI")
            print("Taille du domaine: ntx="+str(d.ntx)+" nty="+str(d.nty))
            print("Dimension de la topologie:", self.dims[0],
                  "suivant x,", self.dims[1], "suivant y")
            print("-----------------------------------------")

    def domain_boundaries(self):
        # Coordonnees dans la topologie -> coords
        coords = self.comm2d.Get_coords(self.rang)
        # Calcule les limites local en X
        self.d.sx = (coords[0]*self.d.ntx)//self.dims[0]+1
        self.d.ex = ((coords[0]+1)*self.d.ntx)//self.dims[0]
        # Calcule les limites local en Y
        self.d.sy = (coords[1]*self.d.nty)//self.dims[1]+1
        self.d.ey = ((coords[1]+1)*self.d.nty)//self.dims[1]
        print(f"rang dans la topologie: {self.rang} Indice des tableaux: "
              f"{self.d.sx} a {self.d.ex} suivant x, "
              f"{self.d.sy} a {self.d.ey} suivant y")

    def domain_neighbours(self):
        # Calcul des voisins Nord et Sud -> self.voisin[N], self.voisin[S]
        self.voisin[N], self.voisin[S] = self.comm2d.Shift(0, 1)
        # Calcul des voisins Ouest et Est -> self.voisin[W], self.voisin[E]
        self.voisin[W], self.voisin[E] = self.comm2d.Shift(1, 1)
        print(f"Processus {self.rang} a pour voisin: "
              f"N {self.voisin[N]} E {self.voisin[E]} "
              f"S {self.voisin[S]} W {self.voisin[W]}")

    def derived_datatypes(self):
        db = MPI.DOUBLE
        # Creation du type ligne -> self.type_ligne
        self.type_ligne = db.Create_contiguous(self.d.ey - self.d.sy + 1)
        self.type_ligne.Commit()
        # Creation du type colonne -> self.type_colonne
        self.type_colonne = db.Create_vector(self.d.ex - self.d.sx + 1, 1,
                                             self.d.ey - self.d.sy + 3)
        self.type_colonne.Commit()

    def communication(self):
        # sx, sy, ex, ey = local indice in self.d.u
        sx = 1
        sy = 1
        ex = self.d.ex - self.d.sx + 1
        ey = self.d.ey - self.d.sy + 1
        # Echange des points aux interfaces
        # Envoi au voisin N et reception du voisin S
        self.comm2d.Sendrecv(sendbuf=[self.d.u[sx, sy:], 1, self.type_ligne],
                             dest=self.voisin[N],
                             recvbuf=[self.d.u[ex + 1, sy:],
                                      1, self.type_ligne],
                             source=self.voisin[S])
        # Envoi au voisin S et reception du voisin N
        self.comm2d.Sendrecv(sendbuf=[self.d.u[ex, sy:], 1, self.type_ligne],
                             dest=self.voisin[S],
                             recvbuf=[self.d.u[sx - 1, sy:],
                                      1, self.type_ligne],
                             source=self.voisin[N])
        # Envoi au voisin W et reception du voisin E
        self.comm2d.Sendrecv(sendbuf=[self.d.u[sx, sy:], 1, self.type_colonne],
                             dest=self.voisin[W],
                             recvbuf=[self.d.u[sx, ey + 1:],
                                      1, self.type_colonne],
                             source=self.voisin[E])
        # Envoi au voisin E et  reception du voisin W
        self.comm2d.Sendrecv(sendbuf=[self.d.u[sx, ey:], 1, self.type_colonne],
                             dest=self.voisin[E],
                             recvbuf=[self.d.u[sx, sy - 1:],
                                      1, self.type_colonne],
                             source=self.voisin[W])

    def global_error(self):
        # Calcul de l'erreur globale (maximum des erreurs locales)
        local_erreur = 0.
        for iterx in range(self.d.sx, self.d.ex + 1):
            for itery in range(self.d.sy, self.d.ey + 1):
                dx = iterx - self.d.sx + 1
                dy = itery - self.d.sy + 1
                delta = np.abs(self.d.u[dx, dy] - self.d.u_new[dx, dy])
                local_erreur = max(local_erreur, delta)
        # Calcul de l'erreur sur tous les sous-domaines
        erreur = self.comm2d.allreduce(local_erreur, op=MPI.MAX)
        return erreur

    def write_data(self):
        # Ouverture du fichier "data.dat" en écriture
        fh = MPI.File.Open(self.comm2d, "data.dat",
                           MPI.MODE_WRONLY | MPI.MODE_CREATE)
        # Creation du type type_sous_tab_vue pour la vue sur le fichier
        shape_array_view = [self.d.ntx, self.d.nty]
        sizex = self.d.ex - self.d.sx + 1
        sizey = self.d.ey - self.d.sy + 1
        shape_subarray_view = [sizex, sizey]
        start_coord_view = [self.d.sx - 1, self.d.sy - 1]
        db = MPI.DOUBLE
        type_subarray_view = db.Create_subarray(shape_array_view,
                                                shape_subarray_view,
                                                start_coord_view,
                                                MPI.ORDER_C)
        type_subarray_view.Commit()
        # Définition de la vue sur le fichier a partir du debut
        fh.Set_view(0, db, type_subarray_view)
        # Creation du type derive type_sous_tab correspondant a la matrice u
        # sans les cellules fantomes
        shape_array = [sizex + 2, sizey + 2]
        shape_subarray = [sizex, sizey]
        start_coord = [1, 1]
        type_subarray = db.Create_subarray(shape_array, shape_subarray,
                                           start_coord, MPI.ORDER_C)
        type_subarray.Commit()
        # Ecriture du tableau u par tous les processus avec la vue
        fh.Write_all([self.d.u, 1, type_subarray])
        # Fermeture du fichier
        fh.Close()
