#include "../../extern/doctest.hpp"
#include "../../src/Parser/Parser.hpp"

using namespace MalachScript;

#define PARSER_TEST(name, tokens, asserts)                                                                             \
    TEST_CASE(name) {                                                                                                  \
        std::vector<Parser::LexToken*> vec = {                                                                         \
            tokens,                                                                                                    \
            new Parser::LexTokenImpl<Parser::LexTokenKind::EndOfFile>(TextSpan(0, 0)),                                 \
        };                                                                                                             \
        for (size_t i = 0; i < vec.size() - 1; i++) {                                                                  \
            vec[i]->SetNext(vec[i + 1]);                                                                               \
        }                                                                                                              \
        Diagnostics::Diagnostics diags;                                                                                \
        auto parser = Parser::Parser(u8"scriptname", vec.front(), &diags);                                             \
        auto* script = parser.Parse();                                                                                 \
        REQUIRE(diags.GetMessages().empty());                                                                          \
        asserts;                                                                                                       \
        delete vec[0];                                                                                                 \
        delete script;                                                                                                 \
    }

#define PARSER_TEST_TOKENS(...) __VA_ARGS__

PARSER_TEST(
    "Parse class foobar { bool foo { get; set; } }",
    PARSER_TEST_TOKENS(new Parser::LexTokenImpl<Parser::LexTokenKind::ClassKeyword>(TextSpan(0, 0)),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"foobar"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"bool"),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"foo"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::GetKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SetKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseCurlyParenthesisSymbol>(TextSpan(0, 0))),
    {
        REQUIRE(script->GetStatements().size() == 1);
        auto firstStatement = script->GetStatements()[0].get();
        REQUIRE(firstStatement->GetKind() == Parser::ParsedStatementKind::Class);
        auto firstClassStatement =
            dynamic_cast<const MalachScript::Parser::ParsedClassStatement*>(firstStatement)->GetBody()[0].get();
        REQUIRE(firstClassStatement->GetKind() == Parser::ParsedStatementKind::VirtProp);
        auto virtPropStatement =
            dynamic_cast<const MalachScript::Parser::ParsedVirtPropStatement*>(firstClassStatement);
        REQUIRE(virtPropStatement->GetAccess() == MalachScript::AccessModifier::Public);
        REQUIRE(virtPropStatement->GetIdentifier().GetString() == u8"foo");
        REQUIRE(virtPropStatement->HasGet());
        REQUIRE(virtPropStatement->HasSet());
        REQUIRE_FALSE(virtPropStatement->IsGetConst());
        REQUIRE_FALSE(virtPropStatement->IsSetConst());
        REQUIRE(virtPropStatement->GetGetStatement() == nullptr);
        REQUIRE(virtPropStatement->GetSetStatement() == nullptr);
    })

PARSER_TEST(
    "Parse class foobar { bool foo { get const; set const; } }",
    PARSER_TEST_TOKENS(new Parser::LexTokenImpl<Parser::LexTokenKind::ClassKeyword>(TextSpan(0, 0)),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"foobar"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"bool"),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"foo"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::GetKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::ConstKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SetKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::ConstKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseCurlyParenthesisSymbol>(TextSpan(0, 0))),
    {
        REQUIRE(script->GetStatements().size() == 1);
        auto firstStatement = script->GetStatements()[0].get();
        REQUIRE(firstStatement->GetKind() == Parser::ParsedStatementKind::Class);
        auto firstClassStatement =
            dynamic_cast<const MalachScript::Parser::ParsedClassStatement*>(firstStatement)->GetBody()[0].get();
        REQUIRE(firstClassStatement->GetKind() == Parser::ParsedStatementKind::VirtProp);
        auto virtPropStatement =
            dynamic_cast<const MalachScript::Parser::ParsedVirtPropStatement*>(firstClassStatement);
        REQUIRE(virtPropStatement->GetAccess() == MalachScript::AccessModifier::Public);
        REQUIRE(virtPropStatement->GetIdentifier().GetString() == u8"foo");
        REQUIRE(virtPropStatement->HasGet());
        REQUIRE(virtPropStatement->HasSet());
        REQUIRE(virtPropStatement->IsGetConst());
        REQUIRE(virtPropStatement->IsSetConst());
        REQUIRE(virtPropStatement->GetGetStatement() == nullptr);
        REQUIRE(virtPropStatement->GetSetStatement() == nullptr);
    })

PARSER_TEST(
    "Parse class foobar { bool foo { get const override; set const override; } }",
    PARSER_TEST_TOKENS(new Parser::LexTokenImpl<Parser::LexTokenKind::ClassKeyword>(TextSpan(0, 0)),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"foobar"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"bool"),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"foo"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::GetKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::ConstKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OverrideKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SetKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::ConstKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OverrideKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseCurlyParenthesisSymbol>(TextSpan(0, 0))),
    {
        REQUIRE(script->GetStatements().size() == 1);
        auto firstStatement = script->GetStatements()[0].get();
        REQUIRE(firstStatement->GetKind() == Parser::ParsedStatementKind::Class);
        auto firstClassStatement =
            dynamic_cast<const MalachScript::Parser::ParsedClassStatement*>(firstStatement)->GetBody()[0].get();
        REQUIRE(firstClassStatement->GetKind() == Parser::ParsedStatementKind::VirtProp);
        auto virtPropStatement =
            dynamic_cast<const MalachScript::Parser::ParsedVirtPropStatement*>(firstClassStatement);
        REQUIRE(virtPropStatement->GetAccess() == MalachScript::AccessModifier::Public);
        REQUIRE(virtPropStatement->GetIdentifier().GetString() == u8"foo");
        REQUIRE(virtPropStatement->HasGet());
        REQUIRE(virtPropStatement->HasSet());
        REQUIRE(virtPropStatement->IsGetConst());
        REQUIRE(virtPropStatement->IsSetConst());
        REQUIRE(FuncAttrHelpers::Contains(virtPropStatement->GetGetFuncAttr(), FuncAttr::Override));
        REQUIRE(FuncAttrHelpers::Contains(virtPropStatement->GetSetFuncAttr(), FuncAttr::Override));
        REQUIRE(virtPropStatement->GetGetStatement() == nullptr);
        REQUIRE(virtPropStatement->GetSetStatement() == nullptr);
    })

/// Parse class foobar {
//    int i;
//    bool foo {
//        get {
//            if (true) return true;
//            return false;
//        }
//        set{
//            if (1 == 1) i++;
//            i--;
//        }
//    }
//}

PARSER_TEST(
    "Virtprops with bodies",
    PARSER_TEST_TOKENS(new Parser::LexTokenImpl<Parser::LexTokenKind::ClassKeyword>(TextSpan(0, 0)),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"foobar"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"int"),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"i"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),

                       new Parser::IdentifierToken(TextSpan(0, 0), u8"bool"),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"foo"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::GetKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::IfKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::TrueKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::ReturnKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::TrueKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::ReturnKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::FalseKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseCurlyParenthesisSymbol>(TextSpan(0, 0)),

                       new Parser::LexTokenImpl<Parser::LexTokenKind::SetKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::IfKeyword>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::OpenParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::IntegerLiteral(TextSpan(0, 0), 1),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::EqualsEqualsSymbol>(TextSpan(0, 0)),
                       new Parser::IntegerLiteral(TextSpan(0, 0), 1),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"i"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::PlusPlusSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),
                       new Parser::IdentifierToken(TextSpan(0, 0), u8"i"),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::MinusMinusSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::SemicolonSymbol>(TextSpan(0, 0)),

                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseCurlyParenthesisSymbol>(TextSpan(0, 0)),
                       new Parser::LexTokenImpl<Parser::LexTokenKind::CloseCurlyParenthesisSymbol>(TextSpan(0, 0))),
    {
        REQUIRE(script->GetStatements().size() == 1);
        auto firstStatement = script->GetStatements()[0].get();
        REQUIRE(firstStatement->GetKind() == Parser::ParsedStatementKind::Class);
        auto firstClassStatement =
            dynamic_cast<const MalachScript::Parser::ParsedClassStatement*>(firstStatement)->GetBody()[1].get();
        REQUIRE(firstClassStatement->GetKind() == Parser::ParsedStatementKind::VirtProp);
        auto virtPropStatement =
            dynamic_cast<const MalachScript::Parser::ParsedVirtPropStatement*>(firstClassStatement);
        REQUIRE(virtPropStatement->GetAccess() == MalachScript::AccessModifier::Public);
        REQUIRE(virtPropStatement->GetIdentifier().GetString() == u8"foo");
        REQUIRE(virtPropStatement->HasGet());
        REQUIRE(virtPropStatement->HasSet());
        REQUIRE_FALSE(virtPropStatement->IsGetConst());
        REQUIRE_FALSE(virtPropStatement->IsSetConst());
        REQUIRE(virtPropStatement->GetGetStatement() != nullptr);
        REQUIRE(virtPropStatement->GetSetStatement() != nullptr);
    })