diff --git a/src/Binder/Binder.cpp b/src/Binder/Binder.cpp index 4433e0d..c533fe0 100644 --- a/src/Binder/Binder.cpp +++ b/src/Binder/Binder.cpp @@ -25,6 +25,7 @@ BoundStatement* Binder::BindStatement(ParsedStatement* statement){ case ParsedStatementKind ::Expression: return this -> BindExpressionStatement(statement); case ParsedStatementKind::Assignment: return this -> BindAssignmentStatement(statement); case ParsedStatementKind ::FunctionDeclaration: return this->BindFunctionDeclarationStatement(statement); + case ParsedStatementKind::Return: return this -> BindReturnStatement(statement); case ParsedStatementKind::Bad: return new BoundBadStatement(); } @@ -105,6 +106,31 @@ BoundStatement *Binder::BindFunctionDeclarationStatement(ParsedStatement *statem return new BoundBadStatement(); } +BoundStatement *Binder::BindReturnStatement(ParsedStatement* statement){ + auto expression = ((ParsedReturnStatement*)statement)->GetExpression(); + shared_ptr currentReturnType; + if (this->_currentFunction == nullptr){ + currentReturnType = this->_scriptData->GetReturnType(); + } else{ + currentReturnType = this->_currentFunction; + } + if (expression == nullptr && currentReturnType != nullptr){ + this -> _scriptData -> Diagnostics -> LogError(DiagnosticCode::InvalidReturnType, statement->GetStartPosition(), statement->GetLength()); + return new BoundBadStatement(); + } + auto boundExpression = this->BindExpression(expression); + auto expresionType = boundExpression->GetType(); + if (currentReturnType == nullptr){ + currentReturnType.swap(expresionType); + return new BoundReturnStatement(boundExpression); + } + if (currentReturnType.get()->operator!=(expresionType.get())){ + this -> _scriptData -> Diagnostics -> LogError(DiagnosticCode::InvalidReturnType, statement->GetStartPosition(), statement->GetLength()); + return new BoundBadStatement(); + } + return new BoundReturnStatement(boundExpression); +} + BoundExpression* Binder::BindExpression(ParsedExpression* expression){ switch (expression -> GetKind()){ case ParsedExpressionKind ::LiteralInteger: diff --git a/src/Binder/Binder.hpp b/src/Binder/Binder.hpp index 4279a46..fb02bd8 100644 --- a/src/Binder/Binder.hpp +++ b/src/Binder/Binder.hpp @@ -10,6 +10,7 @@ class Binder { Script* _scriptData; BoundScope* _scope; + shared_ptr _currentFunction; ~Binder(); @@ -18,6 +19,7 @@ class Binder { BoundStatement *BindExpressionStatement(ParsedStatement *statement); BoundStatement *BindAssignmentStatement(ParsedStatement *statement); BoundStatement *BindFunctionDeclarationStatement(ParsedStatement * statement); + BoundStatement *BindReturnStatement(ParsedStatement *statement); BoundExpression *BindExpression(ParsedExpression *expression); BoundExpression *BindVariableExpression(VariableExpression *expression); diff --git a/src/Binder/BoundStatements/BoundStatement.hpp b/src/Binder/BoundStatements/BoundStatement.hpp index 1f75a52..adec48e 100644 --- a/src/Binder/BoundStatements/BoundStatement.hpp +++ b/src/Binder/BoundStatements/BoundStatement.hpp @@ -1,9 +1,8 @@ -#include #ifndef PORYGONLANG_BOUNDSTATEMENT_HPP #define PORYGONLANG_BOUNDSTATEMENT_HPP - +#include #include #include "../BoundExpressions/BoundExpression.hpp" #include "../BoundVariables/BoundVariableKey.hpp" @@ -18,6 +17,7 @@ enum class BoundStatementKind{ Expression, Assignment, FunctionDeclaration, + Return, }; class BoundStatement{ @@ -116,6 +116,25 @@ public: } }; +class BoundReturnStatement : public BoundStatement{ + BoundExpression* _expression; +public: + explicit BoundReturnStatement(BoundExpression* expression){ + _expression = expression; + } + ~BoundReturnStatement() final{ + delete _expression; + } + + BoundStatementKind GetKind() final{ + return BoundStatementKind ::Return; + } + + BoundExpression* GetExpression(){ + return _expression; + } +}; + #include "BoundFunctionDeclarationStatement.hpp" #endif //PORYGONLANG_BOUNDSTATEMENT_HPP diff --git a/src/Diagnostics/DiagnosticCode.hpp b/src/Diagnostics/DiagnosticCode.hpp index b1c567b..d174d48 100644 --- a/src/Diagnostics/DiagnosticCode.hpp +++ b/src/Diagnostics/DiagnosticCode.hpp @@ -19,6 +19,7 @@ enum class DiagnosticCode{ ParameterCountMismatch, ParameterTypeMismatch, CantIndex, + InvalidReturnType, }; #endif //PORYGONLANG_DIAGNOSTICCODE_HPP diff --git a/src/Evaluator/Evaluator.cpp b/src/Evaluator/Evaluator.cpp index f5bf255..b69d6fd 100644 --- a/src/Evaluator/Evaluator.cpp +++ b/src/Evaluator/Evaluator.cpp @@ -1,8 +1,4 @@ #include - -#include - - #include #include "Evaluator.hpp" #include "EvaluationException.hpp" @@ -18,12 +14,15 @@ void Evaluator::Evaluate(BoundScriptStatement *statement) { } void Evaluator::EvaluateStatement(BoundStatement *statement) { + if (this->_hasReturned) + return; switch (statement->GetKind()){ case BoundStatementKind ::Script: throw; // Should never happen case BoundStatementKind ::Block: return this -> EvaluateBlockStatement((BoundBlockStatement*)statement); case BoundStatementKind ::Expression: return this -> EvaluateExpressionStatement((BoundExpressionStatement*)statement); case BoundStatementKind ::Assignment: return this -> EvaluateAssignmentStatement((BoundAssignmentStatement*)statement); case BoundStatementKind ::FunctionDeclaration: return this->EvaluateFunctionDeclarationStatement((BoundFunctionDeclarationStatement*)statement); + case BoundStatementKind::Return: return this -> EvaluateReturnStatement((BoundReturnStatement*)statement); case BoundStatementKind::Bad: throw; @@ -34,6 +33,8 @@ void Evaluator::EvaluateBlockStatement(BoundBlockStatement* statement) { this->_evaluationScope->OuterScope(); for (auto s: statement->GetStatements()){ this -> EvaluateStatement(s); + if (this->_hasReturned) + break; } this->_evaluationScope->InnerScope(); } @@ -65,6 +66,16 @@ void Evaluator::EvaluateFunctionDeclarationStatement(BoundFunctionDeclarationSta } } +void Evaluator::EvaluateReturnStatement(BoundReturnStatement* statement){ + auto expression = statement->GetExpression(); + this->_hasReturned = true; + if (expression == nullptr){ + return; + } + auto value = this -> EvaluateExpression(expression); + this -> _returnValue = value; +} + shared_ptr Evaluator::EvaluateExpression(BoundExpression *expression) { auto type = expression -> GetType(); switch (type->GetClass()){ @@ -174,11 +185,13 @@ shared_ptr Evaluator::EvaluateFunctionCallExpression(BoundExpression* this->_evaluationScope->CreateVariable(key->GetScopeId(), key->GetIdentifier(), parameter->Clone()); } this->EvaluateBlockStatement(function->GetInnerBlock().get()); - return nullptr; - + this->_hasReturned = false; + auto r = this -> _returnValue; + this -> _returnValue = nullptr; + return r; } -EvalValue* Evaluator::EvaluateFunction(ScriptFunctionEvalValue *function, vector parameters) { +shared_ptr Evaluator::EvaluateFunction(ScriptFunctionEvalValue *function, vector parameters) { auto type = std::dynamic_pointer_cast(function->GetType()); auto parameterTypes = type->GetParameterTypes(); auto parameterKeys = type->GetParameterKeys(); @@ -192,7 +205,10 @@ EvalValue* Evaluator::EvaluateFunction(ScriptFunctionEvalValue *function, vector this->_evaluationScope->CreateVariable(key->GetScopeId(), key->GetIdentifier(), parameter->Clone()); } this->EvaluateBlockStatement(function->GetInnerBlock().get()); - return nullptr; + this->_hasReturned = false; + auto r = this -> _returnValue; + this -> _returnValue = nullptr; + return r; } shared_ptr Evaluator::EvaluateIndexExpression(BoundExpression *expression) { diff --git a/src/Evaluator/Evaluator.hpp b/src/Evaluator/Evaluator.hpp index cf5ddb3..320635c 100644 --- a/src/Evaluator/Evaluator.hpp +++ b/src/Evaluator/Evaluator.hpp @@ -15,7 +15,8 @@ using namespace std; class Evaluator { - shared_ptr _result; + shared_ptr _returnValue; + bool _hasReturned; shared_ptr _lastValue; Script* _scriptData; @@ -26,6 +27,7 @@ class Evaluator { void EvaluateExpressionStatement(BoundExpressionStatement* statement); void EvaluateAssignmentStatement(BoundAssignmentStatement* statement); void EvaluateFunctionDeclarationStatement(BoundFunctionDeclarationStatement *statement); + void EvaluateReturnStatement(BoundReturnStatement *statement); shared_ptr EvaluateExpression(BoundExpression* expression); shared_ptr EvaluateIntegerExpression(BoundExpression* expression); @@ -47,6 +49,8 @@ class Evaluator { public: explicit Evaluator(Script* script){ _scriptData = script; + _hasReturned = false; + _returnValue = nullptr; _evaluationScope = nullptr; } @@ -55,7 +59,7 @@ public: } void Evaluate(BoundScriptStatement* statement); - EvalValue* EvaluateFunction(ScriptFunctionEvalValue* func, vector parameters); + shared_ptr EvaluateFunction(ScriptFunctionEvalValue* func, vector parameters); EvalValue* GetLastValue(){ return _lastValue.get(); diff --git a/src/Parser/ParsedStatements/ParsedStatement.hpp b/src/Parser/ParsedStatements/ParsedStatement.hpp index 2093a1e..c69a2ea 100644 --- a/src/Parser/ParsedStatements/ParsedStatement.hpp +++ b/src/Parser/ParsedStatements/ParsedStatement.hpp @@ -15,7 +15,8 @@ enum class ParsedStatementKind{ Block, Expression, Assignment, - FunctionDeclaration + FunctionDeclaration, + Return }; class ParsedStatement { @@ -166,4 +167,24 @@ public: } }; +class ParsedReturnStatement : public ParsedStatement{ + ParsedExpression* _expression; +public: + ParsedReturnStatement(ParsedExpression* expression, unsigned int start, unsigned int length) : ParsedStatement(start, length){ + _expression = expression; + } + + ~ParsedReturnStatement() final{ + delete _expression; + } + + ParsedStatementKind GetKind() final{ + return ParsedStatementKind ::Return; + } + + ParsedExpression* GetExpression(){ + return _expression; + } +}; + #endif //PORYGONLANG_PARSEDSTATEMENT_HPP diff --git a/src/Parser/Parser.cpp b/src/Parser/Parser.cpp index c677e6b..3735ff7 100644 --- a/src/Parser/Parser.cpp +++ b/src/Parser/Parser.cpp @@ -33,6 +33,7 @@ ParsedStatement* Parser::ParseStatement(IToken* current){ switch (currentKind){ case TokenKind ::LocalKeyword: return this -> ParseAssignment(current); case TokenKind ::FunctionKeyword: return this -> ParseFunctionDeclaration(current); + case TokenKind ::ReturnKeyword: return this->ParseReturnStatement(current); default: break; } if (this->Peek()->GetKind() == TokenKind::AssignmentToken){ @@ -134,7 +135,13 @@ ParsedStatement *Parser::ParseFunctionDeclaration(IToken *current) { } auto functionIdentifier = ((IdentifierToken*) functionIdentifierToken)->Value; return new ParsedFunctionDeclarationStatement(HashedString(functionIdentifier), parameters, (ParsedBlockStatement*)block, start, block->GetEndPosition() - start); +} +ParsedStatement* Parser::ParseReturnStatement(IToken* current){ + //TODO: if next token is on a different line, don't parse it as return expression. + auto expression = this->ParseExpression(this->Next()); + auto start = current->GetStartPosition(); + return new ParsedReturnStatement(expression, start, expression->GetEndPosition() - start); } ParsedExpression* Parser::ParseExpression(IToken* current){ diff --git a/src/Parser/Parser.hpp b/src/Parser/Parser.hpp index 9877e0b..be562e5 100644 --- a/src/Parser/Parser.hpp +++ b/src/Parser/Parser.hpp @@ -29,6 +29,7 @@ class Parser { ParsedStatement* ParseAssignment(IToken* current); ParsedStatement *ParseBlock(const vector& endTokens); ParsedStatement* ParseFunctionDeclaration(IToken* current); + ParsedStatement *ParseReturnStatement(IToken *current); ParsedExpression* ParseExpression(IToken* current); ParsedExpression* ParseBinaryExpression(IToken* current, OperatorPrecedence parentPrecedence); diff --git a/src/Script.cpp b/src/Script.cpp index f3ae89a..4fbe3b1 100644 --- a/src/Script.cpp +++ b/src/Script.cpp @@ -72,7 +72,7 @@ bool Script::HasFunction(const string &key) { return f != _scriptVariables->end() && f.operator->()->second->GetType()->GetClass() == TypeClass ::Function; } -EvalValue *Script::CallFunction(const string &key, vector variables) { +shared_ptr Script::CallFunction(const string &key, vector variables) { auto var = (ScriptFunctionEvalValue*)GetVariable(key); return this->_evaluator->EvaluateFunction(var, std::move(variables)); } @@ -104,7 +104,7 @@ extern "C" { EvalValue* CallFunction(Script* script, const char* key, EvalValue* parameters[], int parameterCount){ std::vector v(parameters, parameters + parameterCount); - return script->CallFunction(key, v); + return script->CallFunction(key, v).get(); } } diff --git a/src/Script.hpp b/src/Script.hpp index c57691c..f0d125d 100644 --- a/src/Script.hpp +++ b/src/Script.hpp @@ -20,6 +20,7 @@ class Script { Evaluator* _evaluator; unordered_map>* _scriptVariables; BoundScriptStatement* _boundScript; + shared_ptr _returnType; explicit Script(); void Parse(string script); @@ -29,6 +30,10 @@ public: ~Script(); + shared_ptr GetReturnType(){ + return _returnType; + } + void Evaluate(); EvalValue* GetLastValue(); @@ -36,10 +41,8 @@ public: EvalValue* GetVariable(const string& key); bool HasVariable(const string& key); - EvalValue* CallFunction(const string& key, vector variables); + shared_ptr CallFunction(const string& key, vector variables); bool HasFunction(const string& key); - - }; diff --git a/tests/integration/Functions.cpp b/tests/integration/Functions.cpp index 37efbb4..2afbe9c 100644 --- a/tests/integration/Functions.cpp +++ b/tests/integration/Functions.cpp @@ -57,5 +57,32 @@ TEST_CASE( "Define script function and call from extern", "[integration]" ) { delete script; } +TEST_CASE( "Define script function and return", "[integration]" ) { + Script* script = Script::Create( + "val = 0\n" + "function add(number a, number b) \n" + "return a + b \n" + "val = val + 1\n" + "end"); + REQUIRE(!script->Diagnostics -> HasErrors()); + script->Evaluate(); + + REQUIRE(script->HasFunction("add")); + auto toAddVal = new IntegerEvalValue(5); + auto toAddVal2 = new IntegerEvalValue(6); + auto result = script->CallFunction("add", {toAddVal, toAddVal2}); + delete toAddVal; + delete toAddVal2; + + REQUIRE(result->GetType()->GetClass() == TypeClass::Number); + REQUIRE(result->EvaluateInteger() == 11); + + auto variable = script->GetVariable("val"); + REQUIRE(variable->GetType()->GetClass() == TypeClass::Number); + REQUIRE(variable->EvaluateInteger() == 0); + + delete script; +} + #endif