import numpy as np
from mpi4py import MPI

# Message tag used for the point-to-point exchanges of this program
tag = 1000
# Define the communicator
comm = MPI.COMM_WORLD

# Number of processes in the communicator and rank of the current process
nprocs = comm.Get_size()
rank = comm.Get_rank()

# Placeholder value on every rank; the real size is read by rank 0 below
n = nprocs
if rank == 0:
    # Read the matrix size N from a file.
    # A failure here must abort the whole job: an exception raised on rank 0
    # alone would leave the other ranks blocked forever in the bcast below.
    try:
        with open("matrix_products.data", "r") as file:
            n = int(file.readline().strip())
        if n <= 0:
            raise ValueError(f"matrix size must be positive, got {n}")
    except (OSError, ValueError) as err:
        # Flush explicitly: Abort terminates the process without running
        # Python's normal stdout flush at exit.
        print(f"Cannot read N from matrix_products.data: {err}", flush=True)
        comm.Abort(1)

# Broadcast the matrix size N to all processes
n = comm.bcast(n, root=0)

# N needs to be divisible by nprocs so every process owns NL = N/nprocs
# full rows of A and full columns of B
if (n % nprocs) != 0:
    if rank == 0:
        print("N is not divisible by Nprocs", flush=True)
    # Every rank sees the same n, so all of them reach this barrier; it
    # guarantees rank 0 has printed before any rank tears the job down.
    comm.Barrier()
    # Stop the program with a non-zero error code
    comm.Abort(1)

# Number of rows of A (and columns of B) handled by each process
nl = n // nprocs

# Only the root process owns the full matrices; the other ranks still need
# the names bound (to None) so the scatter/gather calls can reference them.
if rank != 0:
    A = B = C = CC = None
else:
    # Random operands, drawn in the order A then B
    A, B = np.random.rand(n, n), np.random.rand(n, n)
    # C receives the sequential reference product, CC the parallel result
    C = np.zeros((n, n), dtype=np.float64)
    CC = np.zeros((n, n), dtype=np.float64)
    # Sequential computation of A*B
    np.matmul(A, B, out=C)

# Initialize the local arrays
# AL: NL rows of A, BL: NL columns of B, CL: the matching NL columns of the
# product, TEMP: scratch buffer with the same shape as AL
AL = np.empty((nl, n), dtype=np.float64)
BL = np.empty((n, nl), dtype=np.float64)
CL = np.empty((n, nl), dtype=np.float64)
TEMP = np.empty((nl, n), dtype=np.float64)

# Build the datatype for 1 chunk of N lines and NL columns:
# N blocks of NL doubles, each block starting N doubles after the previous one
type_temp = MPI.DOUBLE.Create_vector(n, nl, n)
# Shrink the extent to NL doubles so that consecutive chunks in a collective
# start NL columns apart instead of a full vector apart
extent = MPI.DOUBLE.Get_size() * nl
type_slice = type_temp.Create_resized(0, extent)
type_slice.Commit()
# The intermediate vector type is no longer needed once the resized type
# exists; release its handle instead of leaking it
type_temp.Free()

# Scatter A to AL and B to BL
# A is sent as contiguous runs of n*nl doubles (NL full rows per process),
# B is sent as one type_slice per process (NL columns per process)
comm.Scatter(sendbuf=[A, n*nl, MPI.DOUBLE], recvbuf=[AL, n*nl, MPI.DOUBLE],
             root=0)
comm.Scatter([B, 1, type_slice], [BL, n*nl, MPI.DOUBLE], root=0)
# Compute the diagonal blocks
np.matmul(AL, BL, out=CL[rank*nl:(rank+1)*nl, :])
# Compute for non-diagonal blocks
# NOTE: a pairwise variant (Sendrecv of AL into TEMP with every other rank)
# is possible as well; the ring exchange below reuses AL in place instead.
previous_rank = (rank - 1) % nprocs
following_rank = (rank + 1) % nprocs
for step in range(1, nprocs):
    # Send AL to previous process and receive from following_rank, in place:
    # after `step` shifts AL holds the rows that started on rank + step
    comm.Sendrecv_replace(buf=[AL, n*nl, MPI.DOUBLE],
                          dest=previous_rank, sendtag=tag,
                          source=following_rank, recvtag=tag)
    # Compute the block above or below the diagonal block
    displacement = (rank + step) % nprocs
    np.matmul(AL, BL, out=CL[displacement*nl:(displacement+1)*nl, :])
# Gather all CL slices to form the C matrix
comm.Gather([CL, n*nl, MPI.DOUBLE], [CC, 1, type_slice], root=0)
# The derived datatype is not used after the last collective: release it
type_slice.Free()
# Deallocate local arrays
del AL, BL, CL, TEMP
# Check the parallel product against the sequential reference.
# Only process 0 holds both C (sequential) and CC (gathered parallel result).
if rank == 0:
    # Largest absolute element-wise difference between the two products
    Emax = np.abs(C - CC).max()
    print(f"Emax = {Emax}")
    if Emax < 1e-10:
        verdict = (
            "Super!",
            "Matrix product A*B in parallel equal the sequential one",
        )
    else:
        verdict = (
            "False result!",
            "Matrix product A*B in parallel different from the sequential one",
        )
    for line in verdict:
        print(line)
