#define CUDATOOLS_IMPLEMENTATION
#define CUDATOOLS_ARRAY_MAX_AXES 8
#include <Array.h>
#include <BLAS.h>
#include <Core.h>
#include <Types.h>

#include <Eigen/Core>
#include <chrono>
#include <complex>

using namespace CudaTools::Types;
namespace CT = CudaTools;

/////////////
// Helpers //
/////////////

// Timing helpers for ad-hoc benchmarking.
//
// TIME_START(name) records `begin_<name>`; TIME_END(name) declares `end_<name>`,
// `time_ms_<name>` and `time_mus_<name>` in the enclosing scope and prints the elapsed time,
// in microseconds when it is under one millisecond and in milliseconds otherwise.
// duration::count() returns an implementation-defined signed integer (long on LP64,
// long long on LLP64), so it is cast to long long to match the "%lld" conversion.
#define TIME_START(name) auto begin_##name = std::chrono::steady_clock::now()

#define TIME_END(name)                                                                             \
    auto end_##name = std::chrono::steady_clock::now();                                            \
    auto time_ms_##name =                                                                          \
        std::chrono::duration_cast<std::chrono::milliseconds>(end_##name - begin_##name).count();  \
    auto time_mus_##name =                                                                         \
        std::chrono::duration_cast<std::chrono::microseconds>(end_##name - begin_##name).count();  \
    if (time_ms_##name == 0) {                                                                     \
        printf("[%s] Time Elapsed: %lld[µs]\n", #name, static_cast<long long>(time_mus_##name));   \
    } else {                                                                                       \
        printf("[%s] Time Elapsed: %lld[ms]\n", #name, static_cast<long long>(time_ms_##name));    \
    }

// Times a single statement `call`; the timing variables stay visible after the macro.
#define TIME(call, name)                                                                           \
    TIME_START(name);                                                                              \
    call;                                                                                          \
    TIME_END(name);

// Records one check. It evaluates `predicate` exactly once and increments the caller's
// `failed` counter (which must exist in the enclosing scope) when the predicate is false.
// It then prints a coloured PASS/FAIL tag followed by "<name> | <msg>.".
// The do/while(0) wrapper makes the macro a single statement, so it is safe in unbraced if/else.
#define TEST(predicate, name, msg)                                                                 \
    do {                                                                                           \
        const bool test_passed_ = static_cast<bool>(predicate);                                    \
        failed += test_passed_ ? 0 : 1;                                                            \
        printf("[%s] ", test_passed_ ? "\033[1;32mPASS\033[0m" : "\033[1;31mFAIL\033[0m");         \
        printf("%s | %s.\n", name, msg);                                                           \
    } while (0)

// Printable-name registry: Type<T>::name holds the source spelling of T.
// Only the specializations registered below exist. For any other T, Type<T> is incomplete,
// so using Type<T>::name fails to compile.
template <typename T> struct Type;

// Declares the explicit specialization Type<TYPE_> and defines its `name` member as the
// stringized macro argument. Because the member definition is emitted here, each
// registration belongs in exactly one translation unit.
#define REGISTER_PARSE_TYPE(TYPE_)                                                                 \
    template <> struct Type<TYPE_> {                                                               \
        static const std::string name;                                                             \
    };                                                                                             \
    const std::string Type<TYPE_>::name = #TYPE_

// Integer element types used by the array tests.
REGISTER_PARSE_TYPE(uint8_t);
REGISTER_PARSE_TYPE(int16_t);
REGISTER_PARSE_TYPE(int32_t);
// Floating-point and complex scalar types (CudaTools::Types aliases).
REGISTER_PARSE_TYPE(real32);
REGISTER_PARSE_TYPE(real64);
REGISTER_PARSE_TYPE(complex64);
REGISTER_PARSE_TYPE(complex128);

std::string box(std::string str) {
    // Frames `str` between two '#' rules that are as wide as the "## <str> ##" line.
    const std::string rule(str.size() + 6, '#');
    std::string framed = rule;
    framed += "\n## ";
    framed += str;
    framed += " ##\n";
    framed += rule;
    return framed;
}

/// Frames `str` in a double-bar box:
///
///     --------
///     || ab ||
///     --------
///
/// The dashed rules are as wide as the *visible* text of "|| <str> ||". ANSI CSI escape
/// sequences (ESC '[' ... final byte in '@'..'~', e.g. the colour codes main() wraps around
/// PASS/FAIL) occupy no columns, so they are not counted. This replaces a fixed
/// `str.size() - 5`, which only fitted strings with exactly 11 bytes of escape codes. That
/// version also underflowed for strings shorter than 5 bytes, making std::string throw
/// std::length_error.
///
/// @param str  Text to frame; may contain ANSI CSI sequences.
/// @return     Three lines: rule, "|| str ||", rule (no trailing newline).
std::string box2(std::string str) {
    std::size_t visible = 0;
    for (std::size_t i = 0; i < str.size(); ++i) {
        if (str[i] == '\033' && i + 1 < str.size() && str[i + 1] == '[') {
            // Skip "ESC [" and the parameter bytes; the loop's ++i then steps past the final byte.
            i += 2;
            while (i < str.size() && (str[i] < '@' || str[i] > '~')) {
                ++i;
            }
            continue;
        }
        ++visible;
    }
    const std::string tops(visible + 6, '-');
    return tops + "\n|| " + str + " ||\n" + tops;
}

std::string boxSmall(std::string str) {
    // Single-line section header: six dashes, "[ <str> ]", six dashes.
    const std::string dashes = "------";
    return dashes + "[ " + str + " ]" + dashes;
}

std::string separator() {
    // A 40-character '=' rule surrounded by newlines.
    return std::string("\n").append(40, '=').append("\n");
}

/// Returns the registered name of T (see REGISTER_PARSE_TYPE) wrapped in bold bright-cyan
/// ANSI escape codes for terminal output.
template <typename T> std::string type() {
    const std::string colorOn = "\033[1;96m";
    const std::string colorOff = "\033[0m";
    return colorOn + Type<T>::name + colorOff;
}

/// Returns a random 2-D shape whose row and column counts are each drawn uniformly from [1, 15].
///
/// The Mersenne Twister engine is seeded from std::random_device once and then reused, so
/// successive calls continue one random stream. Previously a new std::random_device and engine
/// were built on every call. On implementations where std::random_device is deterministic, each
/// fresh device yields the same seed, so every call returned the same shape.
CT::Shape makeRandom2DShape() {
    static std::mt19937 engine(std::random_device{}());
    std::uniform_int_distribution<uint32_t> dist(1, 15);
    return CT::Shape({dist(engine), dist(engine)});
}

///////////
// Tests //
///////////

/// Minimal host/device-mirrored object exercised by MacroTests::Class.
///
/// DEVICE_COPY is a CudaTools macro. It presumably supplies the mirroring members used in this
/// file (allocateDevice, updateDevice, updateHost, that()); confirm against Core.h.
class TestClass {
    DEVICE_COPY(TestClass);

  public:
    int x;

    /// Stores `value` in `x`, allocates the device copy, and waits for the initial upload.
    /// The constructor is `explicit` so an `int` never silently converts into a TestClass, which
    /// would allocate device memory.
    explicit TestClass(const int value) : x(value) {
        allocateDevice();
        updateDevice().wait();
    }
};

// Kernel: doubles every element of `arr` in place.
// BASIC_LOOP (a CudaTools macro) supplies the index `iThread` for each element in
// [0, arr.shape().length()). How the indices are spread over device threads is defined by
// that macro and is not visible here.
// `arr` is passed by const value; writes through arr[iThread] are assumed to reach the
// caller's data (MacroTests::Kernel passes A.view() and reads A back). TODO confirm view semantics.
KERNEL(times, const CT::Array<int> arr) {
    BASIC_LOOP(arr.shape().length()) { arr[iThread] *= 2; }
}

// Kernel: sets `x` to 100 on the TestClass that `test` points to.
// MacroTests::Class launches it with test.that() and then calls test.updateHost(), so `test`
// is presumably the device-side copy of the object.
KERNEL(classTest, TestClass* const test) { test->x = 100; }

// Kernel: applies one Collatz step to every element of `arr`, in place.
// Odd values become 3n + 1; even values are halved (>> 1).
// The index `iThread` comes from the CudaTools BASIC_LOOP macro over [0, arr.shape().length()).
KERNEL(collatz, const CT::Array<uint32_t> arr) {
    BASIC_LOOP(arr.shape().length()) {
        if (arr[iThread] % 2) {
            arr[iThread] = 3 * arr[iThread] + 1;
        } else {
            arr[iThread] = arr[iThread] >> 1;
        }
    }
}

// Kernel: increments every element of `arr` by one, in place.
// The index `iThread` comes from the CudaTools BASIC_LOOP macro over [0, arr.shape().length()).
KERNEL(plusOne, const CT::Array<uint32_t> arr) {
    BASIC_LOOP(arr.shape().length()) { arr[iThread] += 1; }
}

// Kernel: element-wise a[i] += b[i], in place on `a`.
// The loop bound is taken from `a` only. `b` must have at least a.shape().length() elements;
// this is not checked here (the only caller, myGraph, passes arrays of equal size).
KERNEL(addArray, const CT::Array<uint32_t> a, const CT::Array<uint32_t> b) {
    BASIC_LOOP(a.shape().length()) { a[iThread] += b[iThread]; }
}

/// Checks for the kernel-launch and device-copy helper macros.
/// Each method returns the number of failed TEST() checks it ran.
struct MacroTests {
    /// Uploads a 10-element array of ones and launches `times` on it. Reads the array back and
    /// counts elements that are not 2.
    static uint32_t Kernel() {
        uint32_t failed = 0;
        CT::Array<int> A = CT::Array<int>::constant({10}, 1);
        A.updateDevice().wait();
        CT::Kernel::launch(times, CT::Kernel::basic(A.shape().items()), A.view()).wait();
        A.updateHost().wait();

        uint32_t errors = 0;
        for (auto it = A.begin(); it != A.end(); ++it) {
            if (*it != 2) ++errors;
        }

        std::ostringstream msg;
        msg << "Errors: " << errors;
        TEST(errors == 0, "Kernel", msg.str().c_str());
        return failed;
    }

    /// Launches `classTest` on a DEVICE_COPY object's device pointer, copies it back, and checks
    /// that `x` became 100. The printed error count reflects the actual outcome; it used to be
    /// hard-coded to "Errors: 0" even when the check failed.
    static uint32_t Class() {
        uint32_t failed = 0;
        TestClass test(1);
        CT::Kernel::launch(classTest, CT::Kernel::basic(1), test.that()).wait();
        test.updateHost().wait();

        const uint32_t errors = (test.x == 100) ? 0 : 1;
        std::ostringstream msg;
        msg << "Errors: " << errors;
        TEST(errors == 0, "Class", msg.str().c_str());
        return failed;
    }
};

template <typename T> struct ArrayTests {
    static uint32_t Indexing() {
        uint32_t failed = 0;
        CT::Array<T> A = CT::Array<T>::range(0, 240);
        A.reshape({5, 3, 1, 4, 2, 1, 1, 2});

        uint32_t errors = 0;
        for (uint32_t i = 0; i < 5; ++i) {
            for (uint32_t j = 0; j < 3; ++j) {
                for (uint32_t k = 0; k < 4; ++k) {
                    for (uint32_t l = 0; l < 2; ++l) {
                        for (uint32_t m = 0; m < 2; ++m) {
                            if ((T)A[i][j][0][k][l][0][0][m] != (T)A[{i, j, 0, k, l, 0, 0, m}]) {
                                ++errors;
                            }
                        }
                    }
                }
            }
        }

        std::ostringstream msg;
        msg << "Errors: " << errors;
        TEST(errors == 0, "Element", msg.str().c_str());

        errors = 0;
        CT::Array<T> ApartGroup_1 = A[{2, 2}];
        CT::Array<T> ApartIndiv_1 = A[2][2];
        for (uint32_t k = 0; k < 4; ++k) {
            for (uint32_t l = 0; l < 2; ++l) {
                for (uint32_t m = 0; m < 2; ++m) {
                    if ((T)ApartIndiv_1[0][k][l][0][0][m] != (T)ApartGroup_1[{0, k, l, 0, 0, m}]) {
                        ++errors;
                    }
                }
            }
        }

        msg.str("");
        msg << "Errors: " << errors;
        TEST(errors == 0, "Axis (1/2)", msg.str().c_str());

        errors = 0;
        CT::Array<T> ApartGroup_2 = A[{3, 2, 0, 3}];
        CT::Array<T> ApartIndiv_2 = A[3][2][0][3];

        for (uint32_t l = 0; l < 2; ++l) {
            for (uint32_t m = 0; m < 2; ++m) {
                if ((T)ApartIndiv_2[l][0][0][m] != (T)ApartGroup_2[{l, 0, 0, m}]) {
                    ++errors;
                }
            }
        }

        msg.str("");
        msg << "Errors: " << errors;
        TEST(errors == 0, "Axis (2/2)", msg.str().c_str());
        return failed;
    };

    static uint32_t Slicing() {
        uint32_t failed = 0;
        CT::Array<T> A = CT::Array<T>::constant({4, 5, 5}, 0);

        CT::Array<T> Aslice = A.slice({{0, 3}, {1, 4}, {1, 4}});
        T num = (T)1;
        for (auto it = Aslice.begin(); it != Aslice.end(); ++it) {
            *it = num;
            ++num;
        }

        CT::Array<T> Aslice2 = A[3].slice({{0, 5}, {0, 1}});
        num = (T)-1;
        for (auto it = Aslice2.begin(); it != Aslice2.end(); ++it) {
            *it = num;
            --num;
        }

        uint32_t errors = 0;
        for (int i = 0; i < 3; ++i) {
            for (int j = 0; j < 3; ++j) {
                for (int k = 0; k < 3; ++k) {
                    if ((T)A[i][1 + j][1 + k] != (T)(9 * i + 3 * j + k + 1)) {
                        ++errors;
                    }
                }
            }
        }
        std::ostringstream msg;
        msg << "Errors: " << errors;
        TEST(errors == 0, "Block", msg.str().c_str());

        errors = 0;
        for (int i = 0; i < 5; ++i) {
            if ((T)A[3][i][0] != (T)(-(i + 1))) {
                ++errors;
            }
        }

        msg.str("");
        msg << "Errors: " << errors;
        TEST(errors == 0, "Column", msg.str().c_str());
        return failed;
    }
};

/// CudaTools BLAS checks for scalar type T. Each device result is compared against a host-side
/// Eigen reference, and each test returns the number of failed TEST() checks.
template <typename T> struct BLASTests {
    /// A check passes when its residual norm is below this value; it is specialised per scalar
    /// type after the struct definition.
    static double thres;

    /// Runs `attempts` GEMV checks (y = 1.0 * A * x + 0.0 * y) on random 2-D shapes and compares
    /// the device result with an Eigen product computed on the host.
    static uint32_t GEMV(int attempts) {
        uint32_t failed = 0;
        for (int i = 0; i < attempts; i++) {
            CT::Shape Ashape = makeRandom2DShape();
            CT::Shape xshape = CT::Shape({Ashape.cols(), 1});
            CT::Shape yshape = CT::Shape({Ashape.rows(), 1});

            CT::Array<T> A(Ashape);
            CT::Array<T> x(xshape);
            CT::Array<T> y(yshape);

            A.setRandom(-100, 100);
            x.setRandom(-100, 100);

            A.updateDevice();
            x.updateDevice().wait();

            CT::BLAS::GEMV<T>(1.0, A, x, 0.0, y).wait();
            y.updateHost().wait();

            CT::Array<T> yTest(yshape, true);
            yTest.eigenMap() = A.eigenMap() * x.eigenMap();

            // Residual between the host reference and the device result. This line previously
            // subtracted y from itself, so the residual was always 0 and the check could not fail.
            double norm = (yTest.eigenMap() - y.eigenMap()).norm();

            std::ostringstream name;
            name << "GEMV (" << i + 1 << "/" << attempts << ")";
            std::ostringstream msg;
            msg << "Matrix Shape: " << Ashape << ", "
                << "Residual: " << norm;
            TEST(norm < thres, name.str().c_str(), msg.str().c_str());
        }
        return failed;
    }

    /// GEMV over a 2x3 batch of random matrices/vectors stored as 4-D arrays.
    /// Reports the summed residual over all 6 problems.
    static uint32_t GEMVBroadcast() {
        uint32_t failed = 0;
        CT::Shape Ashape = makeRandom2DShape();
        CT::Shape xshape = CT::Shape({Ashape.cols(), 1});
        CT::Shape yshape = CT::Shape({Ashape.rows(), 1});

        CT::Array<T> A({2, 3, Ashape.rows(), Ashape.cols()});
        CT::Array<T> x({2, 3, xshape.rows(), xshape.cols()});
        CT::Array<T> y({2, 3, yshape.rows(), yshape.cols()});

        A.setRandom(-100, 100);
        x.setRandom(-100, 100);

        A.updateDevice();
        x.updateDevice().wait();

        CT::BLAS::GEMV<T>(1.0, A, x, 0.0, y).wait();
        y.updateHost().wait();

        double norm = 0;
        CT::Array<T> yTest(yshape, true);
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 3; ++j) {
                yTest.eigenMap() = A[i][j].eigenMap() * x[i][j].eigenMap();
                norm += (yTest.eigenMap() - y[i][j].eigenMap()).norm();
            }
        }

        std::ostringstream msg;
        msg << "Matrix Shape: " << Ashape << ", "
            << "Residual: " << norm;
        TEST(norm < thres, "GEMV Broadcast", msg.str().c_str());
        return failed;
    }

    /// Runs `attempts` GEMM checks (C = 1.0 * A * B + 0.0 * C) on random compatible shapes.
    static uint32_t GEMM(int attempts) {
        uint32_t failed = 0;
        for (int i = 0; i < attempts; i++) {
            CT::Shape Ashape = makeRandom2DShape();
            CT::Shape Bshape = makeRandom2DShape();
            Bshape = CT::Shape({Ashape.cols(), Bshape.cols()});

            CT::Shape Cshape = CT::Shape({Ashape.rows(), Bshape.cols()});

            CT::Array<T> A(Ashape);
            CT::Array<T> B(Bshape);
            CT::Array<T> C(Cshape);

            A.setRandom(-100, 100);
            B.setRandom(-100, 100);
            C.setRandom(-100, 100);

            A.updateDevice();
            B.updateDevice();
            C.updateDevice().wait();

            CT::BLAS::GEMM<T>(1.0, A, B, 0.0, C).wait();
            C.updateHost().wait();

            CT::Array<T> CTest(Cshape, true);
            CTest.eigenMap() = A.eigenMap() * B.eigenMap();

            double norm = (CTest.eigenMap() - C.eigenMap()).norm();

            std::ostringstream name;
            name << "GEMM (" << i + 1 << "/" << attempts << ")";
            std::ostringstream msg;
            msg << "Matrix Shapes: " << Ashape << Bshape << ", "
                << "Residual: " << norm;
            TEST(norm < thres, name.str().c_str(), msg.str().c_str());
        }
        return failed;
    }

    /// GEMM over a 2x3 batch of random matrices stored as 4-D arrays.
    /// Reports the summed residual over all 6 products.
    static uint32_t GEMMBroadcast() {
        uint32_t failed = 0;
        CT::Shape Ashape = makeRandom2DShape();
        CT::Shape Bshape = makeRandom2DShape();
        Bshape = CT::Shape({Ashape.cols(), Bshape.cols()});

        CT::Shape Cshape = CT::Shape({Ashape.rows(), Bshape.cols()});

        CT::Array<T> A({2, 3, Ashape.rows(), Ashape.cols()});
        CT::Array<T> B({2, 3, Bshape.rows(), Bshape.cols()});
        CT::Array<T> C({2, 3, Cshape.rows(), Cshape.cols()});

        A.setRandom(-100, 100);
        B.setRandom(-100, 100);

        A.updateDevice();
        B.updateDevice();
        C.updateDevice().wait();

        CT::BLAS::GEMM<T>(1.0, A, B, 0.0, C).wait();
        C.updateHost().wait();

        double norm = 0;
        CT::Array<T> CTest(Cshape, true);
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 3; ++j) {
                CTest.eigenMap() = A[i][j].eigenMap() * B[i][j].eigenMap();
                norm += (CTest.eigenMap() - C[i][j].eigenMap()).norm();
            }
        }

        std::ostringstream msg;
        msg << "Matrix Shapes: " << Ashape << Bshape << ", "
            << "Residual: " << norm;
        TEST(norm < thres, "GEMM Broadcast", msg.str().c_str());
        return failed;
    }

    /// Runs a batched PLU factorisation on a 2x3 batch of random square matrices and solves A x = b
    /// for each. It then recomputes A x with GEMM and reports the summed residual against b.
    static uint32_t PLU() {
        uint32_t failed = 0;
        CT::Shape Ashape = makeRandom2DShape();
        CT::Shape xshape = makeRandom2DShape();
        Ashape = CT::Shape({Ashape.rows(), Ashape.rows()});
        xshape = CT::Shape({Ashape.rows(), xshape.cols()});

        CT::Array<T> A({2, 3, Ashape.rows(), Ashape.rows()});
        CT::Array<T> x({2, 3, xshape.rows(), xshape.cols()});
        CT::Array<T> b({2, 3, xshape.rows(), xshape.cols()});
        CT::Array<T> Ax({2, 3, xshape.rows(), xshape.cols()});

        A.setRandom(-100, 100);
        b.setRandom(-100, 100);

        CT::Array<T> LU(A.copy());
        // NOTE(review): whether Array assignment copies data or rebinds a view is not visible
        // here. The residual check relies on b's host data still holding the right-hand side.
        x = b;

        A.updateDevice();
        LU.updateDevice();
        x.updateDevice().wait();

        CT::BLAS::PLUBatch<T> luBatch(LU);
        CT::BLAS::Batch<T> xBatch(x);
        luBatch.computeLU().wait();
        luBatch.solve(xBatch).wait();

        // Compute Ax and compare difference. The download must finish before Ax is read on the
        // host; previously this copy was not waited on.
        CT::BLAS::GEMM<T>(1.0, A, x, 0.0, Ax).wait();
        Ax.updateHost().wait();

        double norm = 0;
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 3; ++j) {
                norm += (Ax[i][j].eigenMap() - b[i][j].eigenMap()).norm();
            }
        }

        std::ostringstream msg;
        msg << "Matrix Shape: " << Ashape << xshape << ", "
            << "Residual: " << norm;
        TEST(norm < thres, "PLU/Solve", msg.str().c_str());
        return failed;
    }
};

// Residual-norm tolerances per scalar type: 1.0 for single precision and 1e-7 for double
// precision. These literals have exactly the values previously spelled 10e-1 and 10e-8.
template <> double BLASTests<float>::thres = 1.0;
template <> double BLASTests<double>::thres = 1e-7;
template <> double BLASTests<complex64>::thres = 1.0;
template <> double BLASTests<complex128>::thres = 1e-7;

/// Runs the macro tests in order (kernel launch, then device-copied class) and returns the
/// total number of failed checks.
uint32_t doMacroTests() {
    uint32_t total = MacroTests::Kernel();
    total += MacroTests::Class();
    std::cout << "\n";
    return total;
}

/// Runs the indexing and slicing suites for element type T under labelled headers and returns
/// the total number of failed checks.
template <typename T> uint32_t doArrayTests() {
    const std::string tag = type<T>();

    std::cout << boxSmall("Index Tests : " + tag) << "\n";
    uint32_t total = ArrayTests<T>::Indexing();

    std::cout << "\n" << boxSmall("Slice Tests : " + tag) << "\n";
    total += ArrayTests<T>::Slicing();

    std::cout << "\n";
    return total;
}

/// Runs the GEMV, GEMM and PLU suites for scalar type T and returns the total number of
/// failed checks. The non-broadcast GEMV and GEMM suites each run five random attempts.
template <typename T> uint32_t doBLASTests() {
    constexpr int attempts = 5;
    const std::string tag = type<T>();

    std::cout << boxSmall("GEMV Tests : " + tag) << "\n";
    uint32_t total = BLASTests<T>::GEMV(attempts);
    total += BLASTests<T>::GEMVBroadcast();

    std::cout << "\n" << boxSmall("GEMM Tests : " + tag) << "\n";
    total += BLASTests<T>::GEMM(attempts);
    total += BLASTests<T>::GEMMBroadcast();

    std::cout << "\n" << boxSmall("PLU Tests : " + tag) << "\n";
    total += BLASTests<T>::PLU();

    std::cout << "\n";
    return total;
}

/// Host callback used by myGraph: adds `num` to every element of A in place, through an Eigen
/// map of A's data. atLeast2D() presumably gives eigenMap() a matrix view even for 1-D arrays;
/// confirm in Array.h.
void addNum(const CT::Array<uint32_t> A, uint32_t num) {
    auto view = A.atLeast2D().eigenMap();
    view.array() += num;
}

// Work recorded into the CT::Graph built by doGraphTest (passed to the CT::Graph constructor
// together with "graphStream"). The statement order defines the stream dependencies, so keep it.
//   1. Upload A on "graphStream", fork "graphStreamBranch" from it, and upload B on the branch.
//   2. 30 rounds: one Collatz step on A (main stream) and +1 on B (branch stream).
//   3. Join the branch back into "graphStream", then A += B on the device.
//   4. Download A and B, then run addNum on the host to add 5 to A.
void myGraph(CT::GraphManager* gm, const CT::Array<uint32_t> A, const CT::Array<uint32_t> B) {
    // NOTE(review): stale, disabled call; there is no `tools` object in this function.
    // tools->launchHostFunction("graphStream", myHostFunc, A->view(), 5);
    A.updateDevice("graphStream");
    gm->makeBranch("graphStream", "graphStreamBranch");
    B.updateDevice("graphStreamBranch");
    for (uint32_t iTimes = 0; iTimes < 30; ++iTimes) {
        CT::Kernel::launch(collatz, CT::Kernel::basic(A.shape().items(), "graphStream"), A.view());
        CT::Kernel::launch(plusOne, CT::Kernel::basic(A.shape().items(), "graphStreamBranch"),
                           B.view());
    }

    // The branch must be joined before addArray reads B on "graphStream".
    gm->joinBranch("graphStream", "graphStreamBranch");
    CT::Kernel::launch(addArray, CT::Kernel::basic(A.shape().items(), "graphStream"), A.view(),
                       B.view());
    A.updateHost("graphStream");
    B.updateHost("graphStream");
    // Host-side step on the downloaded A: every element gets +5.
    gm->launchHostFunction("graphStream", addNum, A.view(), 5);
}

/// Registers the two streams used by myGraph, records and executes the graph, and checks the
/// host copy of A.
///
/// Expected value of every element of A is 36:
///   - 50 reaches 1 after the 30 Collatz steps;
///   - B (starting at 0) reaches 30 after 30 plusOne steps and is added to A, giving 31;
///   - the host callback addNum adds 5.
uint32_t doGraphTest() {
    uint32_t failed = 0;
    CT::Array<uint32_t> A = CT::Array<uint32_t>::constant({1000000}, 50);
    CT::Array<uint32_t> B = CT::Array<uint32_t>::constant({1000000}, 0);
    CT::Manager::get()->addStream("graphStream");
    CT::Manager::get()->addStream("graphStreamBranch");

    CT::GraphManager manager;
    CT::Graph graph("graphStream", myGraph, &manager, A.view(), B.view());
    graph.execute().wait();

    uint32_t errors = 0;
    for (auto it = A.begin(); it != A.end(); ++it) {
        errors += (*it != 36) ? 1 : 0;
    }

    std::ostringstream report;
    report << "Errors: " << errors;
    TEST(errors == 0, "Graph", report.str().c_str());
    return failed;
}

/// Runs every test group and prints a PASS/FAIL summary box.
/// Returns 0 only when all checks passed and 1 otherwise, so scripts and CI can detect
/// failures; previously it always returned 0.
int main() {
    uint32_t failed = 0;
    std::cout << box("Macro Tests") << "\n";
    failed += doMacroTests();

    std::cout << box("Array Tests") << "\n";
    // Test different sizes.
    failed += doArrayTests<uint8_t>();
    failed += doArrayTests<int16_t>();
    failed += doArrayTests<int32_t>();
    failed += doArrayTests<real64>();

    std::cout << box("BLAS Tests") << "\n";
    failed += doBLASTests<real32>();
    failed += doBLASTests<real64>();
    failed += doBLASTests<complex64>();
    failed += doBLASTests<complex128>();

    std::cout << box("Stream/Graph Tests") << "\n";
    failed += doGraphTest();

    // Number of TEST() checks executed above; keep in sync when adding tests:
    //     2 macro checks
    //   + 4 array element types x 5 checks  (3 indexing + 2 slicing)
    //   + 4 BLAS scalar types   x 13 checks (5 GEMV + 1 GEMV broadcast
    //                                        + 5 GEMM + 1 GEMM broadcast + 1 PLU)
    //   + 1 graph check
    constexpr uint32_t tests = 2 + 4 * 5 + 13 * 4 + 1;
    std::ostringstream msg;
    msg << ((failed == 0) ? "\033[1;32mPASS \033[0m(" : "\033[1;31mFAIL \033[0m(")
        << (tests - failed) << "/" << tests << ")";
    std::cout << box2(msg.str()) << "\n";

    return (failed == 0) ? 0 : 1;
}