Make parsed statements constant during binding

This commit is contained in:
2019-06-13 18:14:59 +02:00
parent 601c4a3f89
commit 5910cbbfa9
8 changed files with 238 additions and 215 deletions

View File

@@ -1,9 +1,11 @@
#include <memory>
#include "Binder.hpp"
#include "../TableScriptType.hpp"
#include "BoundExpressions/BoundTableExpression.hpp"
#include <memory>
BoundScriptStatement *Binder::Bind(Script* script, ParsedScriptStatement *s, BoundScope* scriptScope) {
BoundScriptStatement *Binder::Bind(Script* script, const ParsedScriptStatement *s, BoundScope* scriptScope) {
auto binder = Binder();
binder._scriptData = script;
@@ -20,7 +22,7 @@ Binder::~Binder() {
delete _scope;
}
BoundStatement* Binder::BindStatement(ParsedStatement* statement){
BoundStatement* Binder::BindStatement(const ParsedStatement* statement){
switch (statement -> GetKind()) {
case ParsedStatementKind ::Script: throw; // This shouldn't happen.
case ParsedStatementKind ::Block: return this -> BindBlockStatement(statement);
@@ -34,7 +36,7 @@ BoundStatement* Binder::BindStatement(ParsedStatement* statement){
}
}
BoundStatement *Binder::BindBlockStatement(ParsedStatement *statement) {
BoundStatement *Binder::BindBlockStatement(const ParsedStatement *statement) {
auto statements = ((ParsedBlockStatement*)statement)->GetStatements();
vector<BoundStatement*> boundStatements (statements->size());
this->_scope->GoInnerScope();
@@ -45,12 +47,12 @@ BoundStatement *Binder::BindBlockStatement(ParsedStatement *statement) {
return new BoundBlockStatement(boundStatements);
}
BoundStatement *Binder::BindExpressionStatement(ParsedStatement *statement) {
BoundStatement *Binder::BindExpressionStatement(const ParsedStatement *statement) {
auto exp = ((ParsedExpressionStatement*)statement)->GetExpression();
return new BoundExpressionStatement(this -> BindExpression(exp));
}
BoundStatement* Binder::BindAssignmentStatement(ParsedStatement *statement){
BoundStatement* Binder::BindAssignmentStatement(const ParsedStatement *statement){
auto s = (ParsedAssignmentStatement*) statement;
auto boundExpression = this->BindExpression(s->GetExpression());
VariableAssignment assignment =
@@ -76,16 +78,16 @@ std::shared_ptr<ScriptType> ParseTypeIdentifier(HashedString s){
}
}
BoundStatement *Binder::BindFunctionDeclarationStatement(ParsedStatement *statement) {
BoundStatement *Binder::BindFunctionDeclarationStatement(const ParsedStatement *statement) {
auto functionStatement = (ParsedFunctionDeclarationStatement*) statement;
auto parameters = functionStatement->GetParameters();
auto parameterTypes = vector<shared_ptr<ScriptType>>(parameters.size());
auto parameterKeys = vector<shared_ptr<BoundVariableKey>>(parameters.size());
auto parameterTypes = vector<shared_ptr<ScriptType>>(parameters->size());
auto parameterKeys = vector<shared_ptr<BoundVariableKey>>(parameters->size());
auto scopeIndex = this->_scope->GetCurrentScope();
this->_scope->GoInnerScope();
for (int i = 0; i < parameters.size(); i++){
auto var = parameters[i];
for (int i = 0; i < parameters->size(); i++){
auto var = parameters -> at(i);
auto parsedType = ParseTypeIdentifier(var->GetType());
parameterTypes.at(i) = parsedType;
auto parameterAssignment = this->_scope->CreateExplicitLocal(var->GetIdentifier().GetHash(), parsedType);
@@ -114,7 +116,7 @@ BoundStatement *Binder::BindFunctionDeclarationStatement(ParsedStatement *statem
return new BoundFunctionDeclarationStatement(type, assignment.GetKey(), (BoundBlockStatement*)boundBlock);
}
BoundStatement *Binder::BindReturnStatement(ParsedStatement* statement){
BoundStatement *Binder::BindReturnStatement(const ParsedStatement* statement){
auto expression = ((ParsedReturnStatement*)statement)->GetExpression();
shared_ptr<ScriptType> currentReturnType;
if (this->_currentFunction == nullptr){
@@ -143,7 +145,7 @@ BoundStatement *Binder::BindReturnStatement(ParsedStatement* statement){
return new BoundReturnStatement(boundExpression);
}
BoundStatement *Binder::BindConditionalStatement(ParsedStatement* statement) {
BoundStatement *Binder::BindConditionalStatement(const ParsedStatement* statement) {
auto conditionalStatement = (ParsedConditionalStatement*)statement;
auto boundCondition = this -> BindExpression(conditionalStatement -> GetCondition());
if (boundCondition->GetType() -> GetClass() != TypeClass::Bool){
@@ -158,7 +160,7 @@ BoundStatement *Binder::BindConditionalStatement(ParsedStatement* statement) {
return new BoundConditionalStatement(boundCondition, boundBlock, elseStatement);
}
BoundExpression* Binder::BindExpression(ParsedExpression* expression){
BoundExpression* Binder::BindExpression(const ParsedExpression* expression){
switch (expression -> GetKind()){
case ParsedExpressionKind ::LiteralInteger:
return new BoundLiteralIntegerExpression(((LiteralIntegerExpression*)expression)->GetValue(), expression->GetStartPosition(), expression->GetLength());
@@ -193,7 +195,7 @@ BoundExpression* Binder::BindExpression(ParsedExpression* expression){
}
}
BoundExpression* Binder::BindVariableExpression(VariableExpression* expression){
BoundExpression* Binder::BindVariableExpression(const VariableExpression* expression){
auto key = expression->GetValue();
auto scope = this->_scope->Exists(key.GetHash());
if (scope == -1){
@@ -205,7 +207,7 @@ BoundExpression* Binder::BindVariableExpression(VariableExpression* expression){
return new BoundVariableExpression(new BoundVariableKey(key.GetHash(), scope, false), type, expression->GetStartPosition(), expression->GetLength());
}
BoundExpression* Binder::BindBinaryOperator(BinaryExpression* expression){
BoundExpression* Binder::BindBinaryOperator(const BinaryExpression* expression){
auto boundLeft = this -> BindExpression(expression->GetLeft());
auto boundRight = this -> BindExpression(expression->GetRight());
@@ -326,7 +328,7 @@ BoundExpression* Binder::BindBinaryOperator(BinaryExpression* expression){
return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength());
}
BoundExpression* Binder::BindUnaryOperator(UnaryExpression* expression){
BoundExpression* Binder::BindUnaryOperator(const UnaryExpression* expression){
auto operand = this -> BindExpression(expression->GetOperand());
auto operandType = operand -> GetType();
switch (expression->GetOperatorKind()){
@@ -359,7 +361,7 @@ BoundExpression* Binder::BindUnaryOperator(UnaryExpression* expression){
}
BoundExpression* Binder::BindFunctionCall(FunctionCallExpression* expression){
BoundExpression* Binder::BindFunctionCall(const FunctionCallExpression* expression){
auto functionExpression = BindExpression(expression->GetFunction());
auto type = functionExpression->GetType();
if (type->GetClass() != TypeClass::Function){
@@ -370,14 +372,14 @@ BoundExpression* Binder::BindFunctionCall(FunctionCallExpression* expression){
auto functionType = std::dynamic_pointer_cast<FunctionScriptType>(type);
auto parameterTypes = functionType->GetParameterTypes();
auto givenParameters = expression->GetParameters();
if (parameterTypes.size() != givenParameters.size()){
if (parameterTypes.size() != givenParameters->size()){
this->_scriptData->Diagnostics->LogError(DiagnosticCode::ParameterCountMismatch, expression->GetStartPosition(),
expression->GetLength());
return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength());
}
vector<BoundExpression*> boundParameters = vector<BoundExpression*>(givenParameters.size());
for (int i = 0; i < givenParameters.size(); i++){
auto parameter = givenParameters[i];
vector<BoundExpression*> boundParameters = vector<BoundExpression*>(givenParameters->size());
for (int i = 0; i < givenParameters->size(); i++){
auto parameter = givenParameters -> at(i);
auto boundParameter = this -> BindExpression(parameter);
if (boundParameter->GetType().get()->operator!=(parameterTypes.at(i).get())){
this->_scriptData->Diagnostics->LogError(DiagnosticCode::ParameterTypeMismatch, parameter->GetStartPosition(),
@@ -391,7 +393,7 @@ BoundExpression* Binder::BindFunctionCall(FunctionCallExpression* expression){
expression->GetStartPosition(), expression->GetLength());
}
BoundExpression *Binder::BindIndexExpression(IndexExpression *expression) {
BoundExpression *Binder::BindIndexExpression(const IndexExpression *expression) {
auto indexer = this->BindExpression(expression->GetIndexer());
auto index = this->BindExpression(expression->GetIndex());
@@ -404,15 +406,15 @@ BoundExpression *Binder::BindIndexExpression(IndexExpression *expression) {
return new BoundIndexExpression(indexer, index, resultType, expression->GetStartPosition(), expression->GetLength());
}
BoundExpression* Binder::BindNumericalTableExpression(ParsedNumericalTableExpression* expression){
BoundExpression* Binder::BindNumericalTableExpression(const ParsedNumericalTableExpression* expression){
auto expressions = expression->GetExpressions();
auto boundExpressions = vector<BoundExpression*>(expressions.size());
auto boundExpressions = vector<BoundExpression*>(expressions-> size());
shared_ptr<ScriptType> valueType = nullptr;
if (!boundExpressions.empty()){
boundExpressions[0] = this -> BindExpression(expressions[0]);
boundExpressions[0] = this -> BindExpression(expressions -> at(0));
valueType = boundExpressions[0] -> GetType();
for (int i = 1; i < expressions.size(); i++){
boundExpressions[i] = this -> BindExpression(expressions[i]);
for (int i = 1; i < expressions->size(); i++){
boundExpressions[i] = this -> BindExpression(expressions -> at(i));
if (boundExpressions[i] -> GetType().get()->operator!=(valueType.get())){
this->_scriptData->Diagnostics->LogError(DiagnosticCode::InvalidTableValueType, boundExpressions[i]->GetStartPosition(),
boundExpressions[i]->GetLength());
@@ -426,7 +428,7 @@ BoundExpression* Binder::BindNumericalTableExpression(ParsedNumericalTableExpres
return new BoundNumericalTableExpression(boundExpressions, tableType, expression->GetStartPosition(), expression->GetLength());
}
BoundExpression *Binder::BindTableExpression(ParsedTableExpression *expression) {
BoundExpression *Binder::BindTableExpression(const ParsedTableExpression *expression) {
auto tableScope = new unordered_map<int, BoundVariable*>();
auto innerScope = new BoundScope(tableScope);
auto currentScope = this -> _scope;
@@ -434,7 +436,7 @@ BoundExpression *Binder::BindTableExpression(ParsedTableExpression *expression)
auto block = this -> BindBlockStatement(expression -> GetBlock());
this -> _scope = currentScope;
auto tableType = shared_ptr<TableScriptType>(new TableScriptType(tableScope, innerScope->GetLocalVariableCount()));
auto tableType = std::make_shared<TableScriptType>(tableScope, innerScope->GetLocalVariableCount());
delete innerScope;
return new BoundTableExpression((BoundBlockStatement*)block, tableType, expression->GetStartPosition(), expression->GetLength());

View File

@@ -16,24 +16,24 @@ class Binder {
~Binder();
BoundStatement *BindStatement(ParsedStatement *statement);
BoundStatement *BindBlockStatement(ParsedStatement *statement);
BoundStatement *BindExpressionStatement(ParsedStatement *statement);
BoundStatement *BindAssignmentStatement(ParsedStatement *statement);
BoundStatement *BindFunctionDeclarationStatement(ParsedStatement * statement);
BoundStatement *BindReturnStatement(ParsedStatement *statement);
BoundStatement *BindConditionalStatement(ParsedStatement *statement);
BoundStatement *BindStatement(const ParsedStatement *statement);
BoundStatement *BindBlockStatement(const ParsedStatement *statement);
BoundStatement *BindExpressionStatement(const ParsedStatement *statement);
BoundStatement *BindAssignmentStatement(const ParsedStatement *statement);
BoundStatement *BindFunctionDeclarationStatement(const ParsedStatement * statement);
BoundStatement *BindReturnStatement(const ParsedStatement *statement);
BoundStatement *BindConditionalStatement(const ParsedStatement *statement);
BoundExpression *BindExpression(ParsedExpression *expression);
BoundExpression *BindVariableExpression(VariableExpression *expression);
BoundExpression *BindBinaryOperator(BinaryExpression *expression);
BoundExpression *BindUnaryOperator(UnaryExpression *expression);
BoundExpression *BindFunctionCall(FunctionCallExpression *expression);
BoundExpression *BindIndexExpression(IndexExpression *expression);
BoundExpression *BindNumericalTableExpression(ParsedNumericalTableExpression *expression);
BoundExpression *BindTableExpression(ParsedTableExpression * expression);
BoundExpression *BindExpression(const ParsedExpression *expression);
BoundExpression *BindVariableExpression(const VariableExpression *expression);
BoundExpression *BindBinaryOperator(const BinaryExpression *expression);
BoundExpression *BindUnaryOperator(const UnaryExpression *expression);
BoundExpression *BindFunctionCall(const FunctionCallExpression *expression);
BoundExpression *BindIndexExpression(const IndexExpression *expression);
BoundExpression *BindNumericalTableExpression(const ParsedNumericalTableExpression *expression);
BoundExpression *BindTableExpression(const ParsedTableExpression * expression);
public:
static BoundScriptStatement* Bind(Script* script, ParsedScriptStatement* s, BoundScope* scriptScope);
static BoundScriptStatement* Bind(Script* script, const ParsedScriptStatement* s, BoundScope* scriptScope);
};