diff --git a/src/Binder/Binder.cpp b/src/Binder/Binder.cpp index c533fe0..5f09bf8 100644 --- a/src/Binder/Binder.cpp +++ b/src/Binder/Binder.cpp @@ -26,6 +26,7 @@ BoundStatement* Binder::BindStatement(ParsedStatement* statement){ case ParsedStatementKind::Assignment: return this -> BindAssignmentStatement(statement); case ParsedStatementKind ::FunctionDeclaration: return this->BindFunctionDeclarationStatement(statement); case ParsedStatementKind::Return: return this -> BindReturnStatement(statement); + case ParsedStatementKind::Conditional: return this -> BindConditionalStatement(statement); case ParsedStatementKind::Bad: return new BoundBadStatement(); } @@ -131,6 +132,21 @@ BoundStatement *Binder::BindReturnStatement(ParsedStatement* statement){ return new BoundReturnStatement(boundExpression); } +BoundStatement *Binder::BindConditionalStatement(ParsedStatement* statement) { + auto conditionalStatement = (ParsedConditionalStatement*)statement; + auto boundCondition = this -> BindExpression(conditionalStatement -> GetCondition()); + if (boundCondition->GetType() -> GetClass() != TypeClass::Bool){ + this -> _scriptData -> Diagnostics -> LogError(DiagnosticCode::ConditionNotABool, statement->GetStartPosition(), statement->GetLength()); + return new BoundBadStatement(); + } + auto boundBlock = this -> BindStatement(conditionalStatement->GetBlock()); + BoundStatement* elseStatement = nullptr; + if (conditionalStatement->GetElseStatement() != nullptr){ + elseStatement = this -> BindStatement(conditionalStatement->GetElseStatement()); + } + return new BoundConditionalStatement(boundCondition, boundBlock, elseStatement); +} + BoundExpression* Binder::BindExpression(ParsedExpression* expression){ switch (expression -> GetKind()){ case ParsedExpressionKind ::LiteralInteger: @@ -351,3 +367,4 @@ BoundExpression *Binder::BindIndexExpression(IndexExpression *expression) { auto resultType = shared_ptr(indexer->GetType()->GetIndexedType(index->GetType().get())); return new BoundIndexExpression(indexer, index, resultType, expression->GetStartPosition(), expression->GetLength()); } + diff --git a/src/Binder/Binder.hpp b/src/Binder/Binder.hpp index fb02bd8..118a34b 100644 --- a/src/Binder/Binder.hpp +++ b/src/Binder/Binder.hpp @@ -20,6 +20,7 @@ class Binder { BoundStatement *BindAssignmentStatement(ParsedStatement *statement); BoundStatement *BindFunctionDeclarationStatement(ParsedStatement * statement); BoundStatement *BindReturnStatement(ParsedStatement *statement); + BoundStatement *BindConditionalStatement(ParsedStatement *statement); BoundExpression *BindExpression(ParsedExpression *expression); BoundExpression *BindVariableExpression(VariableExpression *expression); diff --git a/src/Binder/BoundStatements/BoundStatement.hpp b/src/Binder/BoundStatements/BoundStatement.hpp index adec48e..ef2dc55 100644 --- a/src/Binder/BoundStatements/BoundStatement.hpp +++ b/src/Binder/BoundStatements/BoundStatement.hpp @@ -18,6 +18,7 @@ enum class BoundStatementKind{ Assignment, FunctionDeclaration, Return, + Conditional, }; class BoundStatement{ @@ -135,6 +136,41 @@ public: } }; +class BoundConditionalStatement : public BoundStatement{ + BoundExpression* _condition; + BoundStatement* _block; + BoundStatement* _elseStatement; +public: + explicit BoundConditionalStatement(BoundExpression* condition, BoundStatement* block, BoundStatement* next){ + _condition = condition; + _block = block; + _elseStatement = next; + } + + ~BoundConditionalStatement() final{ + delete _condition; + delete _block; + delete _elseStatement; + } + + BoundStatementKind GetKind() final{ + return BoundStatementKind ::Conditional; + } + + BoundExpression* GetCondition(){ + return _condition; + } + + BoundStatement* GetBlock(){ + return _block; + } + + BoundStatement* GetElseStatement(){ + return _elseStatement; + } +}; + + #include "BoundFunctionDeclarationStatement.hpp" #endif //PORYGONLANG_BOUNDSTATEMENT_HPP diff --git a/src/Diagnostics/DiagnosticCode.hpp b/src/Diagnostics/DiagnosticCode.hpp index d174d48..c7df63f 100644 --- a/src/Diagnostics/DiagnosticCode.hpp +++ b/src/Diagnostics/DiagnosticCode.hpp @@ -20,6 +20,7 @@ enum class DiagnosticCode{ ParameterTypeMismatch, CantIndex, InvalidReturnType, + ConditionNotABool, }; #endif //PORYGONLANG_DIAGNOSTICCODE_HPP diff --git a/src/Evaluator/Evaluator.cpp b/src/Evaluator/Evaluator.cpp index b69d6fd..93efac6 100644 --- a/src/Evaluator/Evaluator.cpp +++ b/src/Evaluator/Evaluator.cpp @@ -23,6 +23,7 @@ void Evaluator::EvaluateStatement(BoundStatement *statement) { case BoundStatementKind ::Assignment: return this -> EvaluateAssignmentStatement((BoundAssignmentStatement*)statement); case BoundStatementKind ::FunctionDeclaration: return this->EvaluateFunctionDeclarationStatement((BoundFunctionDeclarationStatement*)statement); case BoundStatementKind::Return: return this -> EvaluateReturnStatement((BoundReturnStatement*)statement); + case BoundStatementKind::Conditional: return this -> EvaluateConditionalStatement((BoundConditionalStatement*)statement); case BoundStatementKind::Bad: throw; @@ -76,6 +77,18 @@ void Evaluator::EvaluateReturnStatement(BoundReturnStatement* statement){ this -> _returnValue = value; } +void Evaluator::EvaluateConditionalStatement(BoundConditionalStatement *statement) { + auto condition = statement->GetCondition(); + if (EvaluateBoolExpression(condition) -> EvaluateBool()){ + this -> EvaluateStatement(statement->GetBlock()); + } else{ + auto elseStatement = statement -> GetElseStatement(); + if (elseStatement != nullptr){ + this->EvaluateStatement(elseStatement); + } + } +} + shared_ptr Evaluator::EvaluateExpression(BoundExpression *expression) { auto type = expression -> GetType(); switch (type->GetClass()){ @@ -216,4 +229,5 @@ shared_ptr Evaluator::EvaluateIndexExpression(BoundExpression *expres auto index = this -> EvaluateExpression(indexExpression->GetIndexExpression()); auto indexable = this -> EvaluateExpression(indexExpression->GetIndexableExpression()); return shared_ptr(indexable -> IndexValue(index.get())); -} \ No newline at end of file +} + diff --git a/src/Evaluator/Evaluator.hpp b/src/Evaluator/Evaluator.hpp index 320635c..af5e37c 100644 --- a/src/Evaluator/Evaluator.hpp +++ b/src/Evaluator/Evaluator.hpp @@ -28,6 +28,7 @@ class Evaluator { void EvaluateAssignmentStatement(BoundAssignmentStatement* statement); void EvaluateFunctionDeclarationStatement(BoundFunctionDeclarationStatement *statement); void EvaluateReturnStatement(BoundReturnStatement *statement); + void EvaluateConditionalStatement(BoundConditionalStatement *statement); shared_ptr EvaluateExpression(BoundExpression* expression); shared_ptr EvaluateIntegerExpression(BoundExpression* expression); diff --git a/src/Parser/ParsedStatements/ParsedStatement.hpp b/src/Parser/ParsedStatements/ParsedStatement.hpp index c69a2ea..ff9918a 100644 --- a/src/Parser/ParsedStatements/ParsedStatement.hpp +++ b/src/Parser/ParsedStatements/ParsedStatement.hpp @@ -16,7 +16,8 @@ enum class ParsedStatementKind{ Expression, Assignment, FunctionDeclaration, - Return + Return, + Conditional }; class ParsedStatement { @@ -187,4 +188,47 @@ public: } }; +class ParsedConditionalStatement : public ParsedStatement{ + ParsedExpression* _condition; + ParsedStatement* _block; + // This can be either else if or else + ParsedStatement* _elseStatement; +public: + ParsedConditionalStatement(ParsedExpression* condition, ParsedStatement* block, unsigned int start, unsigned int length) + : ParsedStatement(start, length){ + _condition = condition; + _block = block; + _elseStatement = nullptr; + } + + ParsedConditionalStatement(ParsedExpression* condition, ParsedStatement* block, ParsedStatement* nextStatement, unsigned int start, unsigned int length) + : ParsedStatement(start, length){ + _condition = condition; + _block = block; + _elseStatement = nextStatement; + } + + ~ParsedConditionalStatement() final{ + delete _condition; + delete _block; + delete _elseStatement; + } + + ParsedStatementKind GetKind() final{ + return ParsedStatementKind ::Conditional; + } + + ParsedExpression* GetCondition(){ + return _condition; + } + + ParsedStatement* GetBlock(){ + return _block; + } + + ParsedStatement* GetElseStatement(){ + return _elseStatement; + } +}; + #endif //PORYGONLANG_PARSEDSTATEMENT_HPP diff --git a/src/Parser/Parser.cpp b/src/Parser/Parser.cpp index 3735ff7..912987e 100644 --- a/src/Parser/Parser.cpp +++ b/src/Parser/Parser.cpp @@ -23,6 +23,10 @@ IToken *Parser::Peek() { return this -> _tokens[_position]; } +IToken *Parser::PeekAt(int offset) { + return this -> _tokens[_position + offset]; +} + IToken *Parser::Next() { this -> _position++; return this -> _tokens[_position - 1]; @@ -34,6 +38,7 @@ ParsedStatement* Parser::ParseStatement(IToken* current){ case TokenKind ::LocalKeyword: return this -> ParseAssignment(current); case TokenKind ::FunctionKeyword: return this -> ParseFunctionDeclaration(current); case TokenKind ::ReturnKeyword: return this->ParseReturnStatement(current); + case TokenKind ::IfKeyword: return this -> ParseIfStatement(current, false); default: break; } if (this->Peek()->GetKind() == TokenKind::AssignmentToken){ @@ -144,6 +149,26 @@ ParsedStatement* Parser::ParseReturnStatement(IToken* current){ return new ParsedReturnStatement(expression, start, expression->GetEndPosition() - start); } +ParsedStatement* Parser::ParseIfStatement(IToken* current, bool isElseIf){ + auto condition = this->ParseExpression(this->Next()); + auto next = this -> Next(); + if (next->GetKind() != TokenKind::ThenKeyword){ + this -> ScriptData -> Diagnostics -> LogError(DiagnosticCode::UnexpectedToken, next->GetStartPosition(), next->GetLength()); + return new ParsedBadStatement(next->GetStartPosition(), next->GetLength()); + } + auto block = this -> ParseBlock({TokenKind ::EndKeyword, TokenKind ::ElseKeyword, TokenKind ::ElseIfKeyword}); + auto closeToken = this->PeekAt(-1); + auto start = current->GetStartPosition(); + if (closeToken->GetKind() == TokenKind::ElseIfKeyword){ + auto elseIfStatement = this -> ParseIfStatement(closeToken, true); + return new ParsedConditionalStatement(condition, block, elseIfStatement, start, elseIfStatement->GetEndPosition() - start); + } else if (closeToken->GetKind() == TokenKind::ElseKeyword){ + auto elseStatement = this -> ParseBlock({TokenKind ::EndKeyword}); + return new ParsedConditionalStatement(condition, block, elseStatement, start, elseStatement->GetEndPosition() - start); + } + return new ParsedConditionalStatement(condition, block, start, block->GetEndPosition() - start); +} + ParsedExpression* Parser::ParseExpression(IToken* current){ auto expression = this -> ParseBinaryExpression(current, OperatorPrecedence::No); auto peekKind = this->Peek()->GetKind(); diff --git a/src/Parser/Parser.hpp b/src/Parser/Parser.hpp index be562e5..abf0daf 100644 --- a/src/Parser/Parser.hpp +++ b/src/Parser/Parser.hpp @@ -30,6 +30,7 @@ class Parser { ParsedStatement *ParseBlock(const vector& endTokens); ParsedStatement* ParseFunctionDeclaration(IToken* current); ParsedStatement *ParseReturnStatement(IToken *current); + ParsedStatement *ParseIfStatement(IToken *current, bool isElseIf); ParsedExpression* ParseExpression(IToken* current); ParsedExpression* ParseBinaryExpression(IToken* current, OperatorPrecedence parentPrecedence); @@ -46,6 +47,8 @@ public: ScriptData = scriptData; } + + IToken *PeekAt(int offset); }; diff --git a/tests/integration/ConditionalTests.cpp b/tests/integration/ConditionalTests.cpp new file mode 100644 index 0000000..de6309b --- /dev/null +++ b/tests/integration/ConditionalTests.cpp @@ -0,0 +1,43 @@ +#ifdef TESTS_BUILD +#include +#include "../src/Script.hpp" + +TEST_CASE( "Basic conditional", "[integration]" ) { + Script* script = Script::Create("if true then foo = true end"); + REQUIRE(!script->Diagnostics -> HasErrors()); + auto variable = script->GetVariable("foo"); + REQUIRE(variable == nullptr); + script->Evaluate(); + variable = script->GetVariable("foo"); + REQUIRE(variable != nullptr); + REQUIRE(variable->EvaluateBool()); + delete script; +} + +TEST_CASE( "If then, else", "[integration]" ) { + Script* script = Script::Create("if false then foo = false else foo = true end"); + REQUIRE(!script->Diagnostics -> HasErrors()); + auto variable = script->GetVariable("foo"); + REQUIRE(variable == nullptr); + script->Evaluate(); + variable = script->GetVariable("foo"); + REQUIRE(variable != nullptr); + REQUIRE(variable->EvaluateBool()); + delete script; +} + +TEST_CASE( "If then, else if", "[integration]" ) { + Script* script = Script::Create("if false then foo = false elseif true then foo = true end"); + REQUIRE(!script->Diagnostics -> HasErrors()); + auto variable = script->GetVariable("foo"); + REQUIRE(variable == nullptr); + script->Evaluate(); + variable = script->GetVariable("foo"); + REQUIRE(variable != nullptr); + REQUIRE(variable->EvaluateBool()); + delete script; +} + + +#endif +