From 4408cf00cd78d9ed23747ea4125f6cd5bda58305 Mon Sep 17 00:00:00 2001 From: Deukhoofd Date: Sat, 1 Jun 2019 19:20:31 +0200 Subject: [PATCH] Large overhaul of pointers to shared_ptrs, implemented function evaluation --- src/Binder/Binder.cpp | 119 ++++++++++++------ src/Binder/Binder.hpp | 1 + .../BoundExpressions/BoundExpression.hpp | 64 +++++++--- .../BoundFunctionDeclarationStatement.hpp | 7 +- src/Binder/BoundVariables/BoundScope.cpp | 8 +- src/Binder/BoundVariables/BoundScope.hpp | 5 +- src/Binder/BoundVariables/BoundVariable.hpp | 9 +- src/Diagnostics/DiagnosticCode.hpp | 3 + src/Evaluator/EvalValues/EvalValue.hpp | 13 +- src/Evaluator/EvalValues/NumericEvalValue.hpp | 11 +- .../EvalValues/ScriptFunctionEvalValue.hpp | 32 ++--- src/Evaluator/EvalValues/StringEvalValue.hpp | 9 +- src/Evaluator/Evaluator.cpp | 52 +++++++- src/Evaluator/Evaluator.hpp | 6 + .../ParsedExpressions/ParsedExpression.hpp | 3 +- src/ScriptType.hpp | 35 ++++-- tests/integration/Functions.cpp | 13 ++ 17 files changed, 261 insertions(+), 129 deletions(-) diff --git a/src/Binder/Binder.cpp b/src/Binder/Binder.cpp index 7ecd1b8..5e4a87d 100644 --- a/src/Binder/Binder.cpp +++ b/src/Binder/Binder.cpp @@ -51,8 +51,8 @@ BoundStatement* Binder::BindAssignmentStatement(ParsedStatement *statement){ auto boundExpression = this->BindExpression(s->GetExpression()); VariableAssignment assignment = s->IsLocal() ? - this->_scope->CreateExplicitLocal(s->GetIdentifier().GetHash(), *boundExpression->GetType()) - : this->_scope->AssignVariable(s->GetIdentifier().GetHash(), *boundExpression->GetType()); + this->_scope->CreateExplicitLocal(s->GetIdentifier().GetHash(), boundExpression->GetType()) + : this->_scope->AssignVariable(s->GetIdentifier().GetHash(), boundExpression->GetType()); if (assignment.GetResult() == VariableAssignmentResult::Ok){ auto key = assignment.GetKey(); return new BoundAssignmentStatement(key, boundExpression); @@ -63,28 +63,28 @@ BoundStatement* Binder::BindAssignmentStatement(ParsedStatement *statement){ } } -ScriptType* ParseTypeIdentifier(HashedString s){ +std::shared_ptr ParseTypeIdentifier(HashedString s){ switch (s.GetHash()){ - case HashedString::ConstHash("number"): return new NumericScriptType(false, false); - case HashedString::ConstHash("bool"): return new ScriptType(TypeClass::Bool); - case HashedString::ConstHash("string"): return new ScriptType(TypeClass::String); - default: return new ScriptType(TypeClass::Error); // todo: change to userdata + case HashedString::ConstHash("number"): return std::make_shared(false, false); + case HashedString::ConstHash("bool"): return std::make_shared(TypeClass::Bool); + case HashedString::ConstHash("string"): return std::make_shared(TypeClass::String); + default: return std::make_shared(TypeClass::Error); // todo: change to userdata } } BoundStatement *Binder::BindFunctionDeclarationStatement(ParsedStatement *statement) { auto functionStatement = (ParsedFunctionDeclarationStatement*) statement; auto parameters = functionStatement->GetParameters(); - vector> parameterTypes = vector>(parameters.size()); - vector> parameterKeys = vector>(parameters.size()); + auto parameterTypes = new vector>(parameters.size()); + auto parameterKeys = new vector>(parameters.size()); this->_scope->GoInnerScope(); for (int i = 0; i < parameters.size(); i++){ auto var = parameters[i]; auto parsedType = ParseTypeIdentifier(var->GetType()); - parameterTypes[i] = std::shared_ptr(parsedType); - auto parameterAssignment = this->_scope->CreateExplicitLocal(var->GetIdentifier().GetHash(), *parsedType); + parameterTypes->at(i) = parsedType; + auto parameterAssignment = this->_scope->CreateExplicitLocal(var->GetIdentifier().GetHash(), parsedType); if (parameterAssignment.GetResult() == VariableAssignmentResult::Ok){ - parameterKeys[i] = std::shared_ptr(parameterAssignment.GetKey()); + parameterKeys -> at(i) = std::shared_ptr(parameterAssignment.GetKey()); } else{ //TODO: log error @@ -95,8 +95,10 @@ BoundStatement *Binder::BindFunctionDeclarationStatement(ParsedStatement *statem this->_scope->GoOuterScope(); auto identifier = functionStatement->GetIdentifier(); auto returnType = std::make_shared(TypeClass::Nil); - auto type = new FunctionScriptType(returnType, parameterTypes, parameterKeys); - auto assignment = this->_scope->AssignVariable(identifier.GetHash(), *type); + auto parameterTypesPtr = std::shared_ptr>>(parameterTypes); + auto parameterKeysPtr = std::shared_ptr>>(parameterKeys); + auto type = make_shared(returnType, parameterTypesPtr, parameterKeysPtr); + auto assignment = this->_scope->AssignVariable(identifier.GetHash(), type); if (assignment.GetResult() == VariableAssignmentResult::Ok){ return new BoundFunctionDeclarationStatement(type, assignment.GetKey(), (BoundBlockStatement*)boundBlock); } @@ -123,6 +125,8 @@ BoundExpression* Binder::BindExpression(ParsedExpression* expression){ case ParsedExpressionKind ::Parenthesized: return BindExpression(((ParenthesizedExpression*)expression)->GetInnerExpression()); + case ParsedExpressionKind ::FunctionCall: + return this->BindFunctionCall((FunctionCallExpression*)expression); case ParsedExpressionKind ::Bad: return new BoundBadExpression(expression->GetStartPosition(), expression-> GetLength()); @@ -151,84 +155,89 @@ BoundExpression* Binder::BindBinaryOperator(BinaryExpression* expression){ switch (expression->GetOperatorKind()){ case BinaryOperatorKind ::Addition: if (boundLeftType->GetClass() == TypeClass::Number && boundRightType->GetClass() == TypeClass::Number){ - auto leftNumeric = (NumericScriptType*)boundLeftType; - auto rightNumeric = (NumericScriptType*)boundRightType; + auto leftNumeric = std::dynamic_pointer_cast(boundLeftType); + auto rightNumeric = std::dynamic_pointer_cast(boundRightType); if (leftNumeric->IsAwareOfFloat() && rightNumeric->IsAwareOfFloat()){ return new BoundBinaryExpression(boundLeft, boundRight, - BoundBinaryOperation::Addition, new NumericScriptType(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), + BoundBinaryOperation::Addition, + std::make_shared(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), expression->GetStartPosition(), expression->GetLength()); } else{ - return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Addition, new NumericScriptType(false, false), + return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Addition, + std::make_shared(false, false), expression->GetStartPosition(), expression->GetLength()); } } else if (boundLeftType->GetClass() == TypeClass::String){ - return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Concatenation, new ScriptType(TypeClass::String), + return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Concatenation, std::make_shared(TypeClass::String), expression->GetStartPosition(), expression->GetLength()); } break; case BinaryOperatorKind ::Subtraction: if (boundLeftType->GetClass() == TypeClass::Number && boundRightType->GetClass() == TypeClass::Number){ - auto leftNumeric = (NumericScriptType*)boundLeftType; - auto rightNumeric = (NumericScriptType*)boundRightType; + auto leftNumeric = std::dynamic_pointer_cast(boundLeftType); + auto rightNumeric = std::dynamic_pointer_cast(boundRightType); if (leftNumeric->IsAwareOfFloat() && rightNumeric->IsAwareOfFloat()){ return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Subtraction, - new NumericScriptType(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), + std::make_shared(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), expression->GetStartPosition(), expression->GetLength()); } else{ - return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Subtraction, new NumericScriptType(false, false), + return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Subtraction, + std::make_shared(false, false), expression->GetStartPosition(), expression->GetLength()); } } break; case BinaryOperatorKind ::Multiplication: if (boundLeftType->GetClass() == TypeClass::Number && boundRightType->GetClass() == TypeClass::Number){ - auto leftNumeric = (NumericScriptType*)boundLeftType; - auto rightNumeric = (NumericScriptType*)boundRightType; + auto leftNumeric = std::dynamic_pointer_cast(boundLeftType); + auto rightNumeric = std::dynamic_pointer_cast(boundRightType); if (leftNumeric->IsAwareOfFloat() && rightNumeric->IsAwareOfFloat()){ return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Multiplication, - new NumericScriptType(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), + std::make_shared(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), expression->GetStartPosition(), expression->GetLength()); } else{ - return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Multiplication, new NumericScriptType(false, false), + return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Multiplication, + std::make_shared(false, false), expression->GetStartPosition(), expression->GetLength()); } } break; case BinaryOperatorKind ::Division: if (boundLeftType->GetClass() == TypeClass::Number && boundRightType->GetClass() == TypeClass::Number){ - auto leftNumeric = (NumericScriptType*)boundLeftType; - auto rightNumeric = (NumericScriptType*)boundRightType; + auto leftNumeric = std::dynamic_pointer_cast(boundLeftType); + auto rightNumeric = std::dynamic_pointer_cast(boundRightType); if (leftNumeric->IsAwareOfFloat() && rightNumeric->IsAwareOfFloat()){ return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Division, - new NumericScriptType(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), + std::make_shared(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), expression->GetStartPosition(), expression->GetLength()); } else{ - return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Division, new NumericScriptType(false, false), + return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Division, + std::make_shared(false, false), expression->GetStartPosition(), expression->GetLength()); } } break; case BinaryOperatorKind ::Equality: - return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Equality, new ScriptType(TypeClass::Bool), + return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Equality, std::make_shared(TypeClass::Bool), expression->GetStartPosition(), expression->GetLength()); case BinaryOperatorKind ::Inequality: - return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Inequality, new ScriptType(TypeClass::Bool), + return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Inequality, std::make_shared(TypeClass::Bool), expression->GetStartPosition(), expression->GetLength()); case BinaryOperatorKind ::LogicalAnd: if (boundLeftType->GetClass() == TypeClass::Bool && boundRightType->GetClass() == TypeClass::Bool) - return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::LogicalAnd, new ScriptType(TypeClass::Bool), + return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::LogicalAnd, std::make_shared(TypeClass::Bool), expression->GetStartPosition(), expression->GetLength()); break; case BinaryOperatorKind ::LogicalOr: if (boundLeftType->GetClass() == TypeClass::Bool && boundRightType->GetClass() == TypeClass::Bool) - return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::LogicalOr, new ScriptType(TypeClass::Bool), + return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::LogicalOr, std::make_shared(TypeClass::Bool), expression->GetStartPosition(), expression->GetLength()); break; } @@ -248,14 +257,16 @@ BoundExpression* Binder::BindUnaryOperator(UnaryExpression* expression){ break; case UnaryOperatorKind ::Negation: if (operandType->GetClass() == TypeClass::Number){ - auto innerType = (NumericScriptType*)operandType; - return new BoundUnaryExpression(operand, BoundUnaryOperation::Negation, new NumericScriptType(innerType->IsAwareOfFloat(), - innerType->IsFloat()), expression->GetStartPosition(), expression->GetLength()); + auto innerType = std::dynamic_pointer_cast(operandType); + return new BoundUnaryExpression(operand, BoundUnaryOperation::Negation, + std::make_shared(innerType.get()->IsAwareOfFloat(), innerType.get()->IsFloat()), + expression->GetStartPosition(), expression->GetLength()); } break; case UnaryOperatorKind ::LogicalNegation: if (operandType->GetClass() == TypeClass::Bool){ - return new BoundUnaryExpression(operand, BoundUnaryOperation::LogicalNegation, new ScriptType(TypeClass::Bool), + return new BoundUnaryExpression(operand, BoundUnaryOperation::LogicalNegation, + std::make_shared(TypeClass::Bool), expression->GetStartPosition(), expression->GetLength()); } break; @@ -267,4 +278,34 @@ BoundExpression* Binder::BindUnaryOperator(UnaryExpression* expression){ } +BoundExpression* Binder::BindFunctionCall(FunctionCallExpression* expression){ + auto functionExpression = BindExpression(expression->GetFunction()); + auto type = functionExpression->GetType(); + if (type->GetClass() != TypeClass::Function){ + this->_scriptData->Diagnostics->LogError(DiagnosticCode::ExpressionIsNotAFunction, expression->GetStartPosition(), + expression->GetLength()); + return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); + } + auto functionType = std::dynamic_pointer_cast(type); + auto parameterTypes = functionType->GetParameterTypes(); + auto givenParameters = expression->GetParameters(); + if (parameterTypes->size() != givenParameters.size()){ + this->_scriptData->Diagnostics->LogError(DiagnosticCode::ParameterCountMismatch, expression->GetStartPosition(), + expression->GetLength()); + return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); + } + vector boundParameters = vector(givenParameters.size()); + for (int i = 0; i < givenParameters.size(); i++){ + auto parameter = givenParameters[i]; + auto boundParameter = this -> BindExpression(parameter); + if (boundParameter->GetType().get()->operator!=(parameterTypes.get()-> at(i).get())){ + this->_scriptData->Diagnostics->LogError(DiagnosticCode::ParameterTypeMismatch, parameter->GetStartPosition(), + parameter->GetLength()); + return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); + } + boundParameters[i] = boundParameter; + } + return new BoundFunctionCallExpression(functionExpression, boundParameters, functionType.get()->GetReturnType(), + expression->GetStartPosition(), expression->GetLength()); +} \ No newline at end of file diff --git a/src/Binder/Binder.hpp b/src/Binder/Binder.hpp index 520ed5e..e13127e 100644 --- a/src/Binder/Binder.hpp +++ b/src/Binder/Binder.hpp @@ -23,6 +23,7 @@ class Binder { BoundExpression *BindVariableExpression(VariableExpression *expression); BoundExpression *BindBinaryOperator(BinaryExpression *expression); BoundExpression *BindUnaryOperator(UnaryExpression *expression); + BoundExpression *BindFunctionCall(FunctionCallExpression *expression); public: static BoundScriptStatement* Bind(Script* script, ParsedScriptStatement* s, BoundScope* scriptScope); diff --git a/src/Binder/BoundExpressions/BoundExpression.hpp b/src/Binder/BoundExpressions/BoundExpression.hpp index 40eb2b2..9585bca 100644 --- a/src/Binder/BoundExpressions/BoundExpression.hpp +++ b/src/Binder/BoundExpressions/BoundExpression.hpp @@ -1,10 +1,13 @@ #include +#include + #ifndef PORYGONLANG_BOUNDEXPRESSION_HPP #define PORYGONLANG_BOUNDEXPRESSION_HPP #include +#include #include "../../ScriptType.hpp" #include "../BoundOperators.hpp" #include "../BoundVariables/BoundVariableKey.hpp" @@ -22,24 +25,23 @@ enum class BoundExpressionKind{ Unary, Binary, + FunctionCall, }; class BoundExpression{ unsigned int _start; unsigned int _length; - ScriptType* _type; + std::shared_ptr _type; public: - BoundExpression(unsigned int start, unsigned int length, ScriptType* type){ + BoundExpression(unsigned int start, unsigned int length, std::shared_ptr type){ _start = start; _length = length; _type = type; } - virtual ~BoundExpression(){ - delete _type; - }; + virtual ~BoundExpression() = default; virtual BoundExpressionKind GetKind() = 0; - virtual ScriptType* GetType(){ + virtual std::shared_ptr GetType(){ return _type; }; @@ -56,7 +58,7 @@ public: class BoundBadExpression : public BoundExpression{ public: - BoundBadExpression(unsigned int start, unsigned int length) : BoundExpression(start, length, new ScriptType(TypeClass::Error)){} + BoundBadExpression(unsigned int start, unsigned int length) : BoundExpression(start, length, make_shared(TypeClass::Error)){} BoundExpressionKind GetKind() final{ return BoundExpressionKind ::Bad; @@ -67,7 +69,7 @@ class BoundLiteralIntegerExpression : public BoundExpression{ long _value; public: BoundLiteralIntegerExpression(long value, unsigned int start, unsigned int length) - : BoundExpression(start, length, new NumericScriptType(true, false)){ + : BoundExpression(start, length, make_shared(true, false)){ _value = value; } @@ -84,7 +86,7 @@ class BoundLiteralFloatExpression : public BoundExpression{ double _value; public: BoundLiteralFloatExpression(double value, unsigned int start, unsigned int length) - : BoundExpression(start, length, new NumericScriptType(true, true)){ + : BoundExpression(start, length, make_shared(true, true)){ _value = value; } @@ -101,7 +103,7 @@ class BoundLiteralStringExpression : public BoundExpression{ string _value; public: BoundLiteralStringExpression(string value, unsigned int start, unsigned int length) - : BoundExpression(start, length, new ScriptType(TypeClass::String)){ + : BoundExpression(start, length, make_shared(TypeClass::String)){ _value = std::move(value); } @@ -118,7 +120,7 @@ class BoundLiteralBoolExpression : public BoundExpression{ bool _value; public: BoundLiteralBoolExpression(bool value, unsigned int start, unsigned int length) - : BoundExpression(start, length, new ScriptType(TypeClass::Bool)){ + : BoundExpression(start, length, make_shared(TypeClass::Bool)){ _value = value; } @@ -134,20 +136,15 @@ public: class BoundVariableExpression : public BoundExpression{ int _scope; int _id; - ScriptType _type; public: - BoundVariableExpression(int scope, int id, const ScriptType& type, unsigned int start, unsigned int length) - : BoundExpression(start, length, nullptr), _type(type){ + BoundVariableExpression(int scope, int id, shared_ptr type, unsigned int start, unsigned int length) + : BoundExpression(start, length, std::move(type)){ _scope = scope; _id = id; } ~BoundVariableExpression() override = default; - ScriptType* GetType() final{ - return &_type; - }; - BoundExpressionKind GetKind() final{ return BoundExpressionKind ::Variable; } @@ -166,7 +163,7 @@ class BoundBinaryExpression : public BoundExpression { BoundExpression* _right; BoundBinaryOperation _operation; public: - BoundBinaryExpression(BoundExpression* left, BoundExpression* right, BoundBinaryOperation op, ScriptType* result, + BoundBinaryExpression(BoundExpression* left, BoundExpression* right, BoundBinaryOperation op, shared_ptr result, unsigned int start, unsigned int length) : BoundExpression(start, length, result){ _left = left; @@ -199,7 +196,7 @@ class BoundUnaryExpression : public BoundExpression { BoundExpression* _operand; BoundUnaryOperation _operation; public: - BoundUnaryExpression(BoundExpression* operand, BoundUnaryOperation op, ScriptType* result, unsigned int start, unsigned int length) + BoundUnaryExpression(BoundExpression* operand, BoundUnaryOperation op, shared_ptr result, unsigned int start, unsigned int length) :BoundExpression(start, length, result){ _operand = operand; _operation = op; @@ -222,6 +219,33 @@ public: } }; +class BoundFunctionCallExpression : public BoundExpression { + BoundExpression* _functionExpression; + vector _parameters; +public: + BoundFunctionCallExpression(BoundExpression *functionExpression, vector parameters, shared_ptr result, + unsigned int start, unsigned int length) + : BoundExpression(start, length, result), _functionExpression(functionExpression), _parameters(std::move(parameters)) {} + + ~BoundFunctionCallExpression() final{ + delete _functionExpression; + for (auto p : _parameters){ + delete p; + } + } + + BoundExpressionKind GetKind() final{ + return BoundExpressionKind ::FunctionCall; + } + + BoundExpression* GetFunctionExpression(){ + return _functionExpression; + } + + vector GetParameters(){ + return _parameters; + } +}; #endif //PORYGONLANG_BOUNDEXPRESSION_HPP diff --git a/src/Binder/BoundStatements/BoundFunctionDeclarationStatement.hpp b/src/Binder/BoundStatements/BoundFunctionDeclarationStatement.hpp index 06019f0..0c5869f 100644 --- a/src/Binder/BoundStatements/BoundFunctionDeclarationStatement.hpp +++ b/src/Binder/BoundStatements/BoundFunctionDeclarationStatement.hpp @@ -8,9 +8,9 @@ class BoundFunctionDeclarationStatement : public BoundStatement{ BoundVariableKey* _key; std::shared_ptr _block; - FunctionScriptType* _type; + std::shared_ptr _type; public: - BoundFunctionDeclarationStatement(FunctionScriptType* type, BoundVariableKey* key, BoundBlockStatement* block){ + BoundFunctionDeclarationStatement(std::shared_ptr type, BoundVariableKey* key, BoundBlockStatement* block){ _key = key; _block = shared_ptr(block); _type = type; @@ -18,7 +18,6 @@ public: ~BoundFunctionDeclarationStatement() final{ delete _key; - delete _type; } BoundStatementKind GetKind() final{ @@ -33,7 +32,7 @@ public: return _block; } - FunctionScriptType* GetType(){ + std::shared_ptr GetType(){ return _type; } }; diff --git a/src/Binder/BoundVariables/BoundScope.cpp b/src/Binder/BoundVariables/BoundScope.cpp index 5fb482f..738cf78 100644 --- a/src/Binder/BoundVariables/BoundScope.cpp +++ b/src/Binder/BoundVariables/BoundScope.cpp @@ -1,3 +1,5 @@ +#include + #include "BoundScope.hpp" @@ -70,16 +72,16 @@ BoundVariable *BoundScope::GetVariable(int scope, int identifier) { } } -VariableAssignment BoundScope::CreateExplicitLocal(int identifier, const ScriptType& type) { +VariableAssignment BoundScope::CreateExplicitLocal(int identifier, std::shared_ptr type) { auto scope = this->_localScope.at(this->_currentScope - 1); if (scope -> find(identifier) != scope -> end()){ return VariableAssignment(VariableAssignmentResult::ExplicitLocalVariableExists, nullptr); } - scope -> insert({identifier, new BoundVariable(type)}); + scope -> insert({identifier, new BoundVariable(std::move(type))}); return VariableAssignment(VariableAssignmentResult::Ok, new BoundVariableKey(identifier, this->_currentScope, true)); } -VariableAssignment BoundScope::AssignVariable(int identifier, const ScriptType& type) { +VariableAssignment BoundScope::AssignVariable(int identifier, const std::shared_ptr& type) { int exists = this->Exists(identifier); if (exists == -1){ // Creation diff --git a/src/Binder/BoundVariables/BoundScope.hpp b/src/Binder/BoundVariables/BoundScope.hpp index a0db2b2..2db5949 100644 --- a/src/Binder/BoundVariables/BoundScope.hpp +++ b/src/Binder/BoundVariables/BoundScope.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "BoundVariable.hpp" #include "BoundVariableKey.hpp" #include "VariableAssigmentResult.hpp" @@ -26,8 +27,8 @@ public: int Exists(int key); BoundVariable* GetVariable(int scope, int identifier); - VariableAssignment CreateExplicitLocal(int identifier, const ScriptType& type); - VariableAssignment AssignVariable(int identifier, const ScriptType& type); + VariableAssignment CreateExplicitLocal(int identifier, std::shared_ptr type); + VariableAssignment AssignVariable(int identifier, const std::shared_ptr& type); int GetDeepestScope(){ return _deepestScope; diff --git a/src/Binder/BoundVariables/BoundVariable.hpp b/src/Binder/BoundVariables/BoundVariable.hpp index 786e366..ba5b125 100644 --- a/src/Binder/BoundVariables/BoundVariable.hpp +++ b/src/Binder/BoundVariables/BoundVariable.hpp @@ -2,17 +2,20 @@ #ifndef PORYGONLANG_BOUNDVARIABLE_HPP #define PORYGONLANG_BOUNDVARIABLE_HPP +#include #include "../../ScriptType.hpp" +using namespace std; + class BoundVariable{ - ScriptType _type; + std::shared_ptr _type; public: - explicit BoundVariable(const ScriptType& type) : _type(type){ + explicit BoundVariable(std::shared_ptr type) : _type(type){ } ~BoundVariable(){ } - ScriptType& GetType(){ + std::shared_ptr GetType(){ return _type; } }; diff --git a/src/Diagnostics/DiagnosticCode.hpp b/src/Diagnostics/DiagnosticCode.hpp index 3b8e088..9d49fe1 100644 --- a/src/Diagnostics/DiagnosticCode.hpp +++ b/src/Diagnostics/DiagnosticCode.hpp @@ -15,6 +15,9 @@ enum class DiagnosticCode{ NoUnaryOperationFound, CantAssignVariable, VariableNotFound, + ExpressionIsNotAFunction, + ParameterCountMismatch, + ParameterTypeMismatch, }; #endif //PORYGONLANG_DIAGNOSTICCODE_HPP diff --git a/src/Evaluator/EvalValues/EvalValue.hpp b/src/Evaluator/EvalValues/EvalValue.hpp index 164d887..61d3fb6 100644 --- a/src/Evaluator/EvalValues/EvalValue.hpp +++ b/src/Evaluator/EvalValues/EvalValue.hpp @@ -6,12 +6,13 @@ #include "../EvaluationException.hpp" #include #include +#include class EvalValue{ public: EvalValue() = default; virtual ~EvalValue() = default; - virtual ScriptType* GetType() = 0; + virtual std::shared_ptr GetType() = 0; virtual bool operator ==(EvalValue* b) = 0; @@ -37,22 +38,18 @@ public: class BooleanEvalValue : public EvalValue{ bool _value; - ScriptType* _type; + std::shared_ptr _type; public: explicit BooleanEvalValue(bool val){ _value = val; - _type = new ScriptType(TypeClass::Bool); + _type = std::make_shared(TypeClass::Bool); } EvalValue* Clone() final{ return new BooleanEvalValue(_value); } - ~BooleanEvalValue() final{ - delete _type; - } - - ScriptType* GetType() final{ + std::shared_ptr GetType() final{ return _type; }; diff --git a/src/Evaluator/EvalValues/NumericEvalValue.hpp b/src/Evaluator/EvalValues/NumericEvalValue.hpp index 3238da8..ce65fd9 100644 --- a/src/Evaluator/EvalValues/NumericEvalValue.hpp +++ b/src/Evaluator/EvalValues/NumericEvalValue.hpp @@ -11,13 +11,10 @@ class NumericEvalValue : public EvalValue{ virtual double GetFloatValue() = 0; protected: - ScriptType* _type; + std::shared_ptr _type; public: - ~NumericEvalValue() override{ - delete _type; - }; virtual const bool IsFloat() = 0; - ScriptType* GetType() override { + std::shared_ptr GetType() override { return _type; } @@ -34,7 +31,7 @@ class IntegerEvalValue : public NumericEvalValue{ double GetFloatValue() final{ throw EvaluationException("Attempting to retrieve float from int eval value."); } public: explicit IntegerEvalValue(long value){ - _type = new NumericScriptType(true, false); + _type = std::make_shared(true, false); _value = value; } const bool IsFloat() final{ @@ -62,7 +59,7 @@ class FloatEvalValue : public NumericEvalValue{ double GetFloatValue() final{return _value;} public: explicit FloatEvalValue(double value){ - _type = new NumericScriptType(true, true); + _type = std::make_shared(true, true); _value = value; } const bool IsFloat() final{ diff --git a/src/Evaluator/EvalValues/ScriptFunctionEvalValue.hpp b/src/Evaluator/EvalValues/ScriptFunctionEvalValue.hpp index 56976e6..8d4ca4d 100644 --- a/src/Evaluator/EvalValues/ScriptFunctionEvalValue.hpp +++ b/src/Evaluator/EvalValues/ScriptFunctionEvalValue.hpp @@ -1,3 +1,5 @@ +#include + #ifndef PORYGONLANG_SCRIPTFUNCTIONEVALVALUE_HPP #define PORYGONLANG_SCRIPTFUNCTIONEVALVALUE_HPP @@ -11,44 +13,32 @@ class ScriptFunctionEvalValue : public EvalValue{ std::shared_ptr _innerBlock; - FunctionScriptType _type; + std::shared_ptr _type; public: - explicit ScriptFunctionEvalValue(std::shared_ptr innerBlock, FunctionScriptType type) + explicit ScriptFunctionEvalValue(std::shared_ptr innerBlock, std::shared_ptr type) : _type(std::move(type)) { _innerBlock = std::move(innerBlock); } + std::shared_ptr GetType() final{ + return _type; + } + EvalValue* Clone() final{ return new ScriptFunctionEvalValue(_innerBlock, _type); } - ScriptType* GetType() final{ - return &_type; - }; - bool operator ==(EvalValue* b) final{ if (b->GetType()->GetClass() != TypeClass::Function) return false; return this->_innerBlock == ((ScriptFunctionEvalValue*)b)->_innerBlock; }; - EvalValue* EvaluateFunction(Evaluator* evaluator, const vector& parameters){ - auto parameterTypes = _type.GetParameterTypes(); - auto parameterKeys = _type.GetParameterKeys(); - auto scope = evaluator->GetScope(); - for (int i = 0; i < parameterTypes.size() && i < parameterKeys.size() && i < parameters.size(); i++){ - auto parameter = parameters[i]; - auto requiredType = parameterTypes[i]; - if (parameter->GetType() != requiredType.get()){ - throw EvaluationException("Passed wrong type to function."); - } - auto key = parameterKeys[i]; - scope->CreateVariable(key->GetScopeId(), key->GetIdentifier(), parameter->Clone()); - } - evaluator->EvaluateBlockStatement(_innerBlock.get()); - return nullptr; + std::shared_ptr GetInnerBlock(){ + return _innerBlock; } + }; diff --git a/src/Evaluator/EvalValues/StringEvalValue.hpp b/src/Evaluator/EvalValues/StringEvalValue.hpp index 2426b5b..fe75ac1 100644 --- a/src/Evaluator/EvalValues/StringEvalValue.hpp +++ b/src/Evaluator/EvalValues/StringEvalValue.hpp @@ -9,17 +9,14 @@ using namespace std; class StringEvalValue : public EvalValue{ string _value; - ScriptType* _type; + std::shared_ptr _type; public: explicit StringEvalValue(string s){ _value = move(s); - _type = new ScriptType(TypeClass::String); - } - ~StringEvalValue() final{ - delete _type; + _type = std::make_shared(TypeClass::String); } - ScriptType* GetType() final{ + std::shared_ptr GetType() final{ return _type; }; bool operator ==(EvalValue* b) final{ diff --git a/src/Evaluator/Evaluator.cpp b/src/Evaluator/Evaluator.cpp index 1adc7b9..973b75f 100644 --- a/src/Evaluator/Evaluator.cpp +++ b/src/Evaluator/Evaluator.cpp @@ -52,7 +52,7 @@ void Evaluator::EvaluateFunctionDeclarationStatement(BoundFunctionDeclarationSta auto type = statement->GetType(); auto key = statement->GetKey(); auto block = statement->GetBlock(); - auto value = new ScriptFunctionEvalValue(block, *type); + auto value = new ScriptFunctionEvalValue(block, type); if (key->IsCreation()){ this->_evaluationScope->CreateVariable(key->GetScopeId(), key->GetIdentifier(), value); } else{ @@ -66,6 +66,8 @@ EvalValue *Evaluator::EvaluateExpression(BoundExpression *expression) { case TypeClass ::Number: return this -> EvaluateIntegerExpression(expression); case TypeClass ::Bool: return this -> EvaluateBoolExpression(expression); case TypeClass ::String: return this -> EvaluateStringExpression(expression); + case TypeClass ::Function: return this->EvaluateFunctionExpression(expression); + case TypeClass ::Nil: return this->EvaluateNilExpression(expression); default: throw; } } @@ -81,6 +83,7 @@ NumericEvalValue* Evaluator::EvaluateIntegerExpression(BoundExpression *expressi case BoundExpressionKind::Unary: return this -> EvaluateIntegerUnary((BoundUnaryExpression*)expression); case BoundExpressionKind ::Binary: return this -> EvaluateIntegerBinary((BoundBinaryExpression*)expression); case BoundExpressionKind::Variable: return (NumericEvalValue*)this->GetVariable((BoundVariableExpression*)expression); + case BoundExpressionKind ::FunctionCall: return (NumericEvalValue*)this->EvaluateFunctionCallExpression(expression); case BoundExpressionKind ::LiteralString: case BoundExpressionKind ::LiteralBool: @@ -95,6 +98,7 @@ BooleanEvalValue* Evaluator::EvaluateBoolExpression(BoundExpression *expression) case BoundExpressionKind::Unary: return this -> EvaluateBooleanUnary((BoundUnaryExpression*)expression); case BoundExpressionKind::Binary: return this -> EvaluateBooleanBinary((BoundBinaryExpression*)expression); case BoundExpressionKind::Variable: return (BooleanEvalValue*)this->GetVariable((BoundVariableExpression*)expression); + case BoundExpressionKind ::FunctionCall: return (BooleanEvalValue*)this->EvaluateFunctionCallExpression(expression); case BoundExpressionKind::Bad: case BoundExpressionKind::LiteralInteger: @@ -112,7 +116,7 @@ StringEvalValue* Evaluator::EvaluateStringExpression(BoundExpression *expression case BoundExpressionKind::Binary: return this -> EvaluateStringBinary((BoundBinaryExpression*)expression); case BoundExpressionKind::Variable: return (StringEvalValue*)this->GetVariable((BoundVariableExpression*)expression); - + case BoundExpressionKind ::FunctionCall: return (StringEvalValue*)this->EvaluateFunctionCallExpression(expression); case BoundExpressionKind::Bad: case BoundExpressionKind::LiteralInteger: @@ -121,5 +125,47 @@ StringEvalValue* Evaluator::EvaluateStringExpression(BoundExpression *expression case BoundExpressionKind::Unary: throw; - }} + } +} +EvalValue* Evaluator::EvaluateFunctionExpression(BoundExpression * expression){ + switch (expression->GetKind()){ + case BoundExpressionKind ::Variable: return this->GetVariable((BoundVariableExpression*)expression); + default: throw; + } +} +EvalValue* Evaluator::EvaluateNilExpression(BoundExpression * expression){ + switch (expression->GetKind()){ + case BoundExpressionKind ::FunctionCall: + return this->EvaluateFunctionCallExpression(expression); + default: + return nullptr; + } +} + + +EvalValue* Evaluator::EvaluateFunctionCallExpression(BoundExpression* expression){ + auto functionCall = (BoundFunctionCallExpression*)expression; + auto function = (ScriptFunctionEvalValue*)this->EvaluateExpression(functionCall->GetFunctionExpression()); + auto boundParameters = functionCall->GetParameters(); + auto parameters = vector(boundParameters.size()); + for (int i = 0; i < boundParameters.size(); i++){ + parameters[i] = this->EvaluateExpression(boundParameters[i]); + } + + auto type = std::dynamic_pointer_cast(function->GetType()); + auto parameterTypes = type->GetParameterTypes(); + auto parameterKeys = type->GetParameterKeys(); + for (int i = 0; i < parameterTypes->size() && i < parameterKeys->size() && i < parameters.size(); i++){ + auto parameter = parameters[i]; + auto requiredType = parameterTypes->at(i); + if (*parameter->GetType() != requiredType.get()){ + throw EvaluationException("Passed wrong type to function."); + } + auto key = parameterKeys->at(i); + this->_evaluationScope->CreateVariable(key->GetScopeId(), key->GetIdentifier(), parameter->Clone()); + } + this->EvaluateBlockStatement(function->GetInnerBlock().get()); + return nullptr; + +} \ No newline at end of file diff --git a/src/Evaluator/Evaluator.hpp b/src/Evaluator/Evaluator.hpp index f6575aa..667587e 100644 --- a/src/Evaluator/Evaluator.hpp +++ b/src/Evaluator/Evaluator.hpp @@ -9,6 +9,7 @@ #include "EvalValues/EvalValue.hpp" #include "EvalValues/NumericEvalValue.hpp" #include "EvalValues/StringEvalValue.hpp" +#include "EvalValues/StringEvalValue.hpp" #include "EvaluationScope/EvaluationScope.hpp" using namespace boost; @@ -28,6 +29,8 @@ class Evaluator { NumericEvalValue* EvaluateIntegerExpression(BoundExpression* expression); BooleanEvalValue* EvaluateBoolExpression(BoundExpression* expression); StringEvalValue* EvaluateStringExpression(BoundExpression* expression); + EvalValue* EvaluateFunctionExpression(BoundExpression *expression); + EvalValue *EvaluateNilExpression(BoundExpression *expression); NumericEvalValue* EvaluateIntegerBinary(BoundBinaryExpression* expression); BooleanEvalValue *EvaluateBooleanBinary(BoundBinaryExpression *expression); @@ -35,6 +38,7 @@ class Evaluator { NumericEvalValue* EvaluateIntegerUnary(BoundUnaryExpression* expression); BooleanEvalValue *EvaluateBooleanUnary(BoundUnaryExpression *expression); + EvalValue *EvaluateFunctionCallExpression(BoundExpression *expression); EvalValue *GetVariable(BoundVariableExpression *expression); public: @@ -55,6 +59,8 @@ public: EvaluationScope* GetScope(){ return _evaluationScope; } + + }; diff --git a/src/Parser/ParsedExpressions/ParsedExpression.hpp b/src/Parser/ParsedExpressions/ParsedExpression.hpp index 8e863fc..e58aff7 100644 --- a/src/Parser/ParsedExpressions/ParsedExpression.hpp +++ b/src/Parser/ParsedExpressions/ParsedExpression.hpp @@ -225,9 +225,8 @@ public: FunctionCallExpression(ParsedExpression* function, vector parameters, unsigned int start, unsigned int length) : ParsedExpression(start, length){ _function = std::unique_ptr(function); - _parameters.reserve(parameters.size()); for (int i = 0; i < parameters.size(); i++){ - _parameters[i] = std::unique_ptr(parameters[i]); + _parameters.push_back(std::unique_ptr(parameters[i])); } } diff --git a/src/ScriptType.hpp b/src/ScriptType.hpp index 67d6fbf..0097ed4 100644 --- a/src/ScriptType.hpp +++ b/src/ScriptType.hpp @@ -1,5 +1,9 @@ #include +#include + +#include + #ifndef PORYGONLANG_SCRIPTTYPE_HPP #define PORYGONLANG_SCRIPTTYPE_HPP @@ -8,6 +12,8 @@ #include #include "Binder/BoundVariables/BoundVariableKey.hpp" +using namespace std; + enum class TypeClass{ Error, Nil, @@ -32,11 +38,18 @@ public: return _class; } - virtual bool operator ==(ScriptType b){ + virtual bool operator ==(const ScriptType& b){ return _class == b._class; }; - virtual bool operator !=(ScriptType b){ + virtual bool operator ==(ScriptType* b){ + return _class == b->_class; + }; + + virtual bool operator !=(const ScriptType& b){ + return ! (operator==(b)); + } + virtual bool operator !=(ScriptType* b){ return ! (operator==(b)); } }; @@ -62,26 +75,26 @@ public: }; class FunctionScriptType : public ScriptType{ - std::shared_ptr _returnType; - std::vector> _parameterTypes; - std::vector> _parameterKeys; + shared_ptr _returnType; + shared_ptr>> _parameterTypes; + shared_ptr>> _parameterKeys; public: - FunctionScriptType(std::shared_ptr returnType, std::vector> parameterTypes, - std::vector> parameterKeys) + FunctionScriptType(std::shared_ptr returnType, shared_ptr>> parameterTypes, + shared_ptr>> parameterKeys) : ScriptType(TypeClass::Function){ _returnType = std::move(returnType); _parameterTypes = std::move(parameterTypes); _parameterKeys = std::move(parameterKeys); } - ScriptType* GetReturnType(){ - return _returnType.get(); + shared_ptr GetReturnType(){ + return _returnType; } - std::vector> GetParameterTypes(){ + shared_ptr>> GetParameterTypes(){ return _parameterTypes; } - std::vector> GetParameterKeys(){ + shared_ptr>> GetParameterKeys(){ return _parameterKeys; } }; diff --git a/tests/integration/Functions.cpp b/tests/integration/Functions.cpp index 548fdb6..95b6114 100644 --- a/tests/integration/Functions.cpp +++ b/tests/integration/Functions.cpp @@ -12,4 +12,17 @@ TEST_CASE( "Define script function", "[integration]" ) { delete script; } +TEST_CASE( "Define script function and call", "[integration]" ) { + Script* script = Script::Create("function add(number a, number b) result = a + b end add(1, 2)"); + REQUIRE(!script->Diagnostics -> HasErrors()); + script->Evaluate(); + auto variable = script->GetVariable("add"); + REQUIRE(variable != nullptr); + REQUIRE(variable->GetType()->GetClass() == TypeClass::Function); + auto result = script->GetVariable("result"); + REQUIRE(result->GetType()->GetClass() == TypeClass::Number); + REQUIRE(result->EvaluateInteger() == 3); + delete script; +} + #endif