#include #include "Binder.hpp" #include "../ScriptTypes/TableScriptType.hpp" #include "BoundExpressions/BoundTableExpression.hpp" #include "BoundExpressions/BoundFunctionCallExpression.hpp" #include "BoundExpressions/BoundRequireExpression.hpp" #include "../UserData/UserDataScriptType.hpp" using namespace Porygon::Parser; namespace Porygon::Binder { BoundScriptStatement *Binder::Bind(Script *script, const ParsedScriptStatement *s, BoundScope *scriptScope) { auto binder = Binder(); binder._scriptData = script; binder._scope = scriptScope; auto statements = s->GetStatements(); vector boundStatements(statements->size()); for (size_t i = 0; i < statements->size(); i++) { boundStatements[i] = binder.BindStatement(statements->at(i)); } return new BoundScriptStatement(boundStatements, scriptScope->GetLocalVariableCount()); } Binder::~Binder() { delete _scope; } BoundStatement *Binder::BindStatement(const ParsedStatement *statement) { switch (statement->GetKind()) { case ParsedStatementKind::Script: throw; // This shouldn't happen. case ParsedStatementKind::Block: return this->BindBlockStatement(statement); case ParsedStatementKind::Expression: return this->BindExpressionStatement(statement); case ParsedStatementKind::Assignment: return this->BindAssignmentStatement(statement); case ParsedStatementKind::IndexAssignment: return this->BindIndexAssignmentStatement(statement); case ParsedStatementKind::FunctionDeclaration: return this->BindFunctionDeclarationStatement(statement); case ParsedStatementKind::Return: return this->BindReturnStatement(statement); case ParsedStatementKind::Conditional: return this->BindConditionalStatement(statement); case ParsedStatementKind::NumericalFor: return this->BindNumericalForStatement(statement); case ParsedStatementKind::GenericFor: return this->BindGenericForStatement(statement); case ParsedStatementKind::While: return this->BindWhileStatement(statement); case ParsedStatementKind::Break: //TODO: Validate we're in a loop return new BoundBreakStatement(); case ParsedStatementKind::Bad: return new BoundBadStatement(); } throw "unreachable"; } BoundStatement *Binder::BindBlockStatement(const ParsedStatement *statement) { auto statements = ((ParsedBlockStatement *) statement)->GetStatements(); vector boundStatements(statements->size()); this->_scope->GoInnerScope(); for (size_t i = 0; i < statements->size(); i++) { boundStatements[i] = this->BindStatement(statements->at(i)); } this->_scope->GoOuterScope(); return new BoundBlockStatement(boundStatements); } BoundStatement *Binder::BindExpressionStatement(const ParsedStatement *statement) { auto exp = ((ParsedExpressionStatement *) statement)->GetExpression(); return new BoundExpressionStatement(this->BindExpression(exp)); } BoundStatement *Binder::BindAssignmentStatement(const ParsedStatement *statement) { auto s = (ParsedAssignmentStatement *) statement; auto boundExpression = this->BindExpression(s->GetExpression()); VariableAssignment assignment = s->IsLocal() ? this->_scope->CreateExplicitLocal(s->GetIdentifier(), boundExpression->GetType()) : this->_scope->AssignVariable(s->GetIdentifier(), boundExpression->GetType()); if (assignment.GetResult() == VariableAssignmentResult::Ok) { auto key = assignment.GetKey(); return new BoundAssignmentStatement(key, boundExpression); } else { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::CantAssignVariable, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } } BoundStatement *Binder::BindIndexAssignmentStatement(const ParsedStatement *statement) { auto s = (ParsedIndexAssignmentStatement *) statement; auto indexExp = s->GetIndexExpression(); const BoundExpression *indexable; if (indexExp->GetKind() == ParsedExpressionKind::Indexer) { indexable = this->BindIndexExpression((IndexExpression *) indexExp, true); } else { indexable = this->BindPeriodIndexExpression((PeriodIndexExpression *) indexExp, true); } auto valueExpression = this->BindExpression(s->GetValueExpression()); auto boundIndexType = indexable->GetType(); if (boundIndexType->GetClass() != TypeClass::Error && boundIndexType->operator!=(valueExpression->GetType())) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidTableValueType, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } return new BoundIndexAssignmentStatement(indexable, valueExpression); } std::shared_ptr ParseTypeIdentifier(const HashedString *s) { auto hash = s->GetHash(); switch (hash) { case HashedString::ConstHash("number"): return NumericScriptType::Unaware; case HashedString::ConstHash("bool"): return ScriptType::BoolType; case HashedString::ConstHash("string"): return StringScriptType::Dynamic; case HashedString::ConstHash(("any")): return ScriptType::AnyType; case HashedString::ConstHash("table"): return make_shared(); default: if (!UserData::UserDataStorage::HasUserDataType(hash)) { return nullptr; } return std::make_shared(hash); } } BoundStatement *Binder::BindFunctionDeclarationStatement(const ParsedStatement *statement) { auto functionStatement = (ParsedFunctionDeclarationStatement *) statement; auto parameters = functionStatement->GetParameters(); auto parameterTypes = vector>(parameters->size()); auto parameterKeys = vector>(parameters->size()); this->_scope->GoInnerScope(); for (size_t i = 0; i < parameters->size(); i++) { auto var = parameters->at(i); auto parsedType = ParseTypeIdentifier(var->GetType()); if (parsedType == nullptr) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidTypeName, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } parameterTypes.at(i) = parsedType; auto parameterAssignment = this->_scope->CreateExplicitLocal(*var->GetIdentifier(), parsedType); if (parameterAssignment.GetResult() == VariableAssignmentResult::Ok) { parameterKeys.at(i) = std::shared_ptr(parameterAssignment.GetKey()); } else { //TODO: log error continue; } } auto identifier = functionStatement->GetIdentifier(); auto returnType = ScriptType::NilType; auto option = new ScriptFunctionOption(returnType, parameterTypes, parameterKeys); this->_currentFunction = option; shared_ptr type; auto scope = this->_scope->Exists(identifier); const BoundVariableKey *assignmentKey; if (scope >= 0) { auto var = this->_scope->GetVariable(scope, identifier); auto varType = var->GetType(); if (varType->GetClass() != TypeClass::Function) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::CantAssignVariable, statement->GetStartPosition(), statement->GetLength()); } type = dynamic_pointer_cast(varType); type->RegisterFunctionOption(option); assignmentKey = new BoundVariableKey(identifier, scope, false, type); } else { type = make_shared(); type->RegisterFunctionOption(option); auto assignment = this->_scope->AssignVariable(identifier, type); if (assignment.GetResult() != VariableAssignmentResult::Ok) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::CantAssignVariable, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } assignmentKey = assignment.GetKey(); } auto boundBlock = dynamic_cast(this->BindBlockStatement(functionStatement->GetBlock())); this->_scope->GoOuterScope(); this->_currentFunction = nullptr; return new BoundFunctionDeclarationStatement(type, assignmentKey, boundBlock); } BoundStatement *Binder::BindReturnStatement(const ParsedStatement *statement) { auto expression = ((ParsedReturnStatement *) statement)->GetExpression(); shared_ptr currentReturnType; if (this->_currentFunction == nullptr) { currentReturnType = this->_scriptData->GetReturnType(); } else { currentReturnType = this->_currentFunction->GetReturnType(); } if (expression == nullptr && (currentReturnType != nullptr && currentReturnType->GetClass() != TypeClass::Nil)) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidReturnType, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } else if (expression == nullptr) { currentReturnType = ScriptType::NilType; return new BoundReturnStatement(nullptr); } auto boundExpression = this->BindExpression(expression); auto expresionType = boundExpression->GetType(); if (currentReturnType == nullptr || currentReturnType->GetClass() == TypeClass::Nil) { if (this->_currentFunction == nullptr) { this->_scriptData->SetReturnType(expresionType); } else { this->_currentFunction->SetReturnType(expresionType); } return new BoundReturnStatement(boundExpression); } if (currentReturnType.get()->operator!=(expresionType)) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidReturnType, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } return new BoundReturnStatement(boundExpression); } BoundStatement *Binder::BindConditionalStatement(const ParsedStatement *statement) { auto conditionalStatement = (ParsedConditionalStatement *) statement; auto boundCondition = this->BindExpression(conditionalStatement->GetCondition()); if (boundCondition->GetType()->GetClass() != TypeClass::Bool) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::ConditionNotABool, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } auto boundBlock = this->BindStatement(conditionalStatement->GetBlock()); BoundStatement *elseStatement = nullptr; if (conditionalStatement->GetElseStatement() != nullptr) { elseStatement = this->BindStatement(conditionalStatement->GetElseStatement()); } return new BoundConditionalStatement(boundCondition, boundBlock, elseStatement); } BoundStatement *Binder::BindNumericalForStatement(const ParsedStatement *statement) { auto forStatement = (ParsedNumericalForStatement *) statement; auto identifier = forStatement->GetIdentifier(); auto start = this->BindExpression(forStatement->GetStart()); auto end = this->BindExpression(forStatement->GetEnd()); auto parsedStep = forStatement->GetStep(); BoundExpression *step = nullptr; if (parsedStep != nullptr) { step = this->BindExpression(parsedStep); } if (start->GetType()->GetClass() != TypeClass::Number) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::NumericalForArgumentNotANumber, start->GetStartPosition(), start->GetLength()); return new BoundBadStatement(); } if (end->GetType()->GetClass() != TypeClass::Number) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::NumericalForArgumentNotANumber, end->GetStartPosition(), end->GetLength()); return new BoundBadStatement(); } if (step != nullptr && step->GetType()->GetClass() != TypeClass::Number) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::NumericalForArgumentNotANumber, step->GetStartPosition(), step->GetLength()); return new BoundBadStatement(); } this->_scope->GoInnerScope(); auto variableKey = this->_scope->CreateExplicitLocal(identifier, NumericScriptType::AwareInt); if (variableKey.GetResult() != VariableAssignmentResult::Ok) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::CantAssignVariable, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } auto block = this->BindBlockStatement(forStatement->GetBlock()); this->_scope->GoOuterScope(); return new BoundNumericalForStatement(variableKey.GetKey(), start, end, step, block); } BoundStatement *Binder::BindGenericForStatement(const ParsedStatement *statement) { auto genericFor = (ParsedGenericForStatement *) statement; auto boundIterator = BindExpression(genericFor->GetIteratorExpression()); const auto &itType = boundIterator->GetType(); if (!itType->CanBeIterated()) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::CantIterateExpression, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } auto keyType = itType->GetIteratorKeyType(); auto keyIdentifier = genericFor->GetKeyIdentifier(); this->_scope->GoInnerScope(); auto keyVariableAssignment = this->_scope->CreateExplicitLocal(keyIdentifier, keyType); if (keyVariableAssignment.GetResult() != VariableAssignmentResult::Ok) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::CantAssignVariable, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } auto keyVariable = keyVariableAssignment.GetKey(); auto valueIdentifier = genericFor->GetValueIdentifier(); auto isValueVariableDefined = valueIdentifier.GetHash() != 0; const BoundVariableKey *valueVariable = nullptr; if (isValueVariableDefined) { auto valueType = itType->GetIndexedType(keyType.get()); auto valueVariableAssignment = this->_scope->CreateExplicitLocal(valueIdentifier, valueType); if (valueVariableAssignment.GetResult() != VariableAssignmentResult::Ok) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::CantAssignVariable, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } valueVariable = valueVariableAssignment.GetKey(); } auto boundBlock = this->BindBlockStatement(genericFor->GetBlock()); this->_scope->GoOuterScope(); return new BoundGenericForStatement(keyVariable, valueVariable, boundIterator, boundBlock); } BoundStatement *Binder::BindWhileStatement(const ParsedStatement *statement) { auto whileStatement = (ParsedWhileStatement *) statement; auto boundCondition = this->BindExpression(whileStatement->GetCondition()); if (boundCondition->GetType()->GetClass() != TypeClass::Bool) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::ConditionNotABool, statement->GetStartPosition(), statement->GetLength()); return new BoundBadStatement(); } auto boundBlock = this->BindBlockStatement(whileStatement->GetBlock()); return new BoundWhileStatement(boundCondition, boundBlock); } ///////////////// // Expressions // ///////////////// BoundExpression *Binder::BindExpression(const ParsedExpression *expression) { switch (expression->GetKind()) { case ParsedExpressionKind::LiteralInteger: return new BoundLiteralIntegerExpression(((LiteralIntegerExpression *) expression)->GetValue(), expression->GetStartPosition(), expression->GetLength()); case ParsedExpressionKind::LiteralFloat: return new BoundLiteralFloatExpression(((LiteralFloatExpression *) expression)->GetValue(), expression->GetStartPosition(), expression->GetLength()); case ParsedExpressionKind::LiteralString: return new BoundLiteralStringExpression(((LiteralStringExpression *) expression)->GetValue(), expression->GetStartPosition(), expression->GetLength()); case ParsedExpressionKind::LiteralBool: return new BoundLiteralBoolExpression(((LiteralBoolExpression *) expression)->GetValue(), expression->GetStartPosition(), expression->GetLength()); case ParsedExpressionKind ::Nil: return new BoundNilExpression(expression->GetStartPosition(), expression->GetLength()); case ParsedExpressionKind::Variable: return this->BindVariableExpression((VariableExpression *) expression); case ParsedExpressionKind::Binary: return this->BindBinaryOperator((BinaryExpression *) expression); case ParsedExpressionKind::Unary: return this->BindUnaryOperator((UnaryExpression *) expression); case ParsedExpressionKind::Parenthesized: return BindExpression(((ParenthesizedExpression *) expression)->GetInnerExpression()); case ParsedExpressionKind::FunctionCall: return this->BindFunctionCall((FunctionCallExpression *) expression); case ParsedExpressionKind::Indexer: return this->BindIndexExpression((IndexExpression *) expression, false); case ParsedExpressionKind::PeriodIndexer: return this->BindPeriodIndexExpression((PeriodIndexExpression *) expression, false); case ParsedExpressionKind::NumericalTable: return this->BindNumericalTableExpression((ParsedNumericalTableExpression *) expression); case ParsedExpressionKind::Table: return this->BindTableExpression((ParsedTableExpression *) expression); case ParsedExpressionKind::Bad: return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } throw; } BoundExpression *Binder::BindVariableExpression(const VariableExpression *expression) { auto key = expression->GetValue(); auto scope = this->_scope->Exists(key); if (scope == -1) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::VariableNotFound, expression->GetStartPosition(), expression->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } auto var = this->_scope->GetVariable(scope, key); auto type = var->GetType(); return new BoundVariableExpression(new BoundVariableKey(key, scope, false, type), type, expression->GetStartPosition(), expression->GetLength()); } BoundExpression *Binder::BindBinaryOperator(const BinaryExpression *expression) { auto boundLeft = this->BindExpression(expression->GetLeft()); auto boundRight = this->BindExpression(expression->GetRight()); auto boundLeftType = boundLeft->GetType(); auto boundRightType = boundRight->GetType(); auto kind = expression->GetOperatorKind(); if (boundLeftType->GetClass() == TypeClass::UserData){ auto ud = dynamic_pointer_cast(boundLeftType); auto op = ud->GetUserData()->Get()->GetBinaryOperation(kind, boundRightType); if (op == nullptr){ if (kind != BinaryOperatorKind::Equality && kind != BinaryOperatorKind::Inequality){ this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::NoBinaryOperationFound, expression->GetStartPosition(), expression->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } } else{ return new BoundUserdataBinaryExpression(boundLeft, boundRight, op, op->GetReturnType(), expression->GetStartPosition(), expression->GetLength()); } } switch (kind) { case BinaryOperatorKind::Addition: if (boundLeftType->GetClass() == TypeClass::Number && boundRightType->GetClass() == TypeClass::Number) { auto leftNumeric = std::static_pointer_cast(boundLeftType); auto rightNumeric = std::static_pointer_cast(boundRightType); if (leftNumeric->IsAwareOfFloat() && rightNumeric->IsAwareOfFloat()) { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Addition, NumericScriptType::ResolveType(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), expression->GetStartPosition(), expression->GetLength()); } else { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Addition, NumericScriptType::Unaware, expression->GetStartPosition(), expression->GetLength()); } } else if (boundLeftType->GetClass() == TypeClass::String) { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Concatenation, StringScriptType::Dynamic, expression->GetStartPosition(), expression->GetLength()); } break; case BinaryOperatorKind::Subtraction: if (boundLeftType->GetClass() == TypeClass::Number && boundRightType->GetClass() == TypeClass::Number) { auto leftNumeric = std::dynamic_pointer_cast(boundLeftType); auto rightNumeric = std::dynamic_pointer_cast(boundRightType); if (leftNumeric->IsAwareOfFloat() && rightNumeric->IsAwareOfFloat()) { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Subtraction, NumericScriptType::ResolveType(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), expression->GetStartPosition(), expression->GetLength()); } else { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Subtraction, NumericScriptType::Unaware, expression->GetStartPosition(), expression->GetLength()); } } break; case BinaryOperatorKind::Multiplication: if (boundLeftType->GetClass() == TypeClass::Number && boundRightType->GetClass() == TypeClass::Number) { auto leftNumeric = std::dynamic_pointer_cast(boundLeftType); auto rightNumeric = std::dynamic_pointer_cast(boundRightType); if (leftNumeric->IsAwareOfFloat() && rightNumeric->IsAwareOfFloat()) { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Multiplication, NumericScriptType::ResolveType(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), expression->GetStartPosition(), expression->GetLength()); } else { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Multiplication, NumericScriptType::Unaware, expression->GetStartPosition(), expression->GetLength()); } } break; case BinaryOperatorKind::Division: if (boundLeftType->GetClass() == TypeClass::Number && boundRightType->GetClass() == TypeClass::Number) { auto leftNumeric = std::dynamic_pointer_cast(boundLeftType); auto rightNumeric = std::dynamic_pointer_cast(boundRightType); if (leftNumeric->IsAwareOfFloat() && rightNumeric->IsAwareOfFloat()) { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Division, NumericScriptType::ResolveType(true, leftNumeric->IsFloat() || rightNumeric->IsFloat()), expression->GetStartPosition(), expression->GetLength()); } else { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Division, NumericScriptType::Unaware, expression->GetStartPosition(), expression->GetLength()); } } break; case BinaryOperatorKind::Equality: return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Equality, ScriptType::BoolType, expression->GetStartPosition(), expression->GetLength()); case BinaryOperatorKind::Inequality: return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Inequality, ScriptType::BoolType, expression->GetStartPosition(), expression->GetLength()); case BinaryOperatorKind::Less: if (boundLeft->GetType()->GetClass() == TypeClass::Number && boundRight->GetType()->GetClass() == TypeClass::Number) { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::LessThan, ScriptType::BoolType, expression->GetStartPosition(), expression->GetLength()); } case BinaryOperatorKind::LessOrEquals: if (boundLeft->GetType()->GetClass() == TypeClass::Number && boundRight->GetType()->GetClass() == TypeClass::Number) { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::LessThanEquals, ScriptType::BoolType, expression->GetStartPosition(), expression->GetLength()); } case BinaryOperatorKind::Greater: if (boundLeft->GetType()->GetClass() == TypeClass::Number && boundRight->GetType()->GetClass() == TypeClass::Number) { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::GreaterThan, ScriptType::BoolType, expression->GetStartPosition(), expression->GetLength()); } case BinaryOperatorKind::GreaterOrEquals: if (boundLeft->GetType()->GetClass() == TypeClass::Number && boundRight->GetType()->GetClass() == TypeClass::Number) { return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::GreaterThanEquals, ScriptType::BoolType, expression->GetStartPosition(), expression->GetLength()); } case BinaryOperatorKind::LogicalAnd: if (boundLeftType->GetClass() == TypeClass::Bool && boundRightType->GetClass() == TypeClass::Bool) return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::LogicalAnd, ScriptType::BoolType, expression->GetStartPosition(), expression->GetLength()); break; case BinaryOperatorKind::LogicalOr: if (boundLeftType->GetClass() == TypeClass::Bool && boundRightType->GetClass() == TypeClass::Bool) return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::LogicalOr, ScriptType::BoolType, expression->GetStartPosition(), expression->GetLength()); break; } this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::NoBinaryOperationFound, expression->GetStartPosition(), expression->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } BoundExpression *Binder::BindUnaryOperator(const UnaryExpression *expression) { auto operand = this->BindExpression(expression->GetOperand()); auto operandType = operand->GetType(); switch (expression->GetOperatorKind()) { case UnaryOperatorKind::Identity: if (operandType->GetClass() == TypeClass::Number) { // Identity won't change anything during evaluation, so just return the inner operand. return operand; } break; case UnaryOperatorKind::Negation: if (operandType->GetClass() == TypeClass::Number) { auto innerType = std::dynamic_pointer_cast(operandType); return new BoundUnaryExpression(operand, BoundUnaryOperation::Negation, NumericScriptType::ResolveType( innerType.get()->IsAwareOfFloat(), innerType.get()->IsFloat()), expression->GetStartPosition(), expression->GetLength()); } break; case UnaryOperatorKind::LogicalNegation: if (operandType->GetClass() == TypeClass::Bool) { return new BoundUnaryExpression(operand, BoundUnaryOperation::LogicalNegation, ScriptType::BoolType, expression->GetStartPosition(), expression->GetLength()); } break; default: break; } this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::NoUnaryOperationFound, expression->GetStartPosition(), expression->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } BoundExpression *Binder::BindFunctionCall(const FunctionCallExpression *expression) { auto func = expression->GetFunction(); if (func->GetKind() == ParsedExpressionKind::Variable) { auto variable = dynamic_cast(func); auto hash = variable->GetValue().GetHash(); if (hash == HashedString::ConstHash("require")) { return this->BindRequire(expression); } else if (hash == HashedString::ConstHash("cast")) { return this->BindCast(expression); } } auto functionExpression = BindExpression(func); auto type = functionExpression->GetType(); if (type->GetClass() != TypeClass::Function) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::ExpressionIsNotAFunction, expression->GetStartPosition(), expression->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } auto functionType = std::dynamic_pointer_cast(type); auto givenParameters = expression->GetParameters(); vector boundParameters = vector(givenParameters->size()); for (size_t i = 0; i < givenParameters->size(); i++) { boundParameters[i] = this->BindExpression(givenParameters->at(i)); } auto functionOption = functionType->GetFunctionOption(this->_scriptData->Diagnostics, &boundParameters); if (functionOption == nullptr) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidFunctionParameters, expression->GetStartPosition(), expression->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } return new BoundFunctionCallExpression(functionExpression, boundParameters, functionOption, functionOption->GetReturnType(), expression->GetStartPosition(), expression->GetLength()); } BoundExpression *Binder::BindRequire(const FunctionCallExpression *exp) { auto parameters = exp->GetParameters(); if (parameters->size() != 1) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidFunctionParameters, exp->GetStartPosition(), exp->GetLength()); return new BoundBadExpression(exp->GetStartPosition(), exp->GetLength()); } auto parameter = parameters->at(0); auto boundParameter = this->BindExpression(parameter); if (boundParameter->GetKind() != BoundExpressionKind::LiteralString) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidFunctionParameters, exp->GetStartPosition(), exp->GetLength()); return new BoundBadExpression(exp->GetStartPosition(), exp->GetLength()); } auto key = *dynamic_cast(boundParameter)->GetValue(); auto opt = this->_scriptData->GetScriptOptions(); auto transformedKey = Utilities::StringUtils::FromUTF8(key); delete boundParameter; if (!opt->DoesModuleExist(transformedKey.c_str(), transformedKey.size())) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::ModuleDoesntExist, exp->GetStartPosition(), exp->GetLength()); return new BoundBadExpression(exp->GetStartPosition(), exp->GetLength()); } auto module = Script::Clone(opt->ResolveModule(transformedKey.c_str(), transformedKey.size())); if (module->GetReturnType() == nullptr) { for (const auto &v: *module->GetScriptVariables()) { auto type = module->GetVariableType(v.first); auto result = this->_scope->AssignVariable(v.first, type); delete result.GetKey(); } } return new BoundRequireExpression(module, exp->GetStartPosition(), exp->GetLength()); } BoundExpression *Binder::BindCast(const FunctionCallExpression* exp){ auto parameters = exp->GetParameters(); if (parameters->size() != 2) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidFunctionParameters, exp->GetStartPosition(), exp->GetLength()); return new BoundBadExpression(exp->GetStartPosition(), exp->GetLength()); } auto toCastParameter = this ->BindExpression(parameters->at(0)); const auto& toCastParameterType = toCastParameter->GetType(); auto destinationTypeParameter = parameters -> at(1); if (destinationTypeParameter ->GetKind() != ParsedExpressionKind::Variable){ this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidFunctionParameters, exp->GetStartPosition(), exp->GetLength()); return new BoundBadExpression(exp->GetStartPosition(), exp->GetLength()); } auto destinationTypeContent = dynamic_cast(destinationTypeParameter)->GetValue(); auto destinationType = ParseTypeIdentifier(&destinationTypeContent); auto castResult = toCastParameterType->CastableTo(destinationType, true); if (castResult == CastResult::InvalidCast){ this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidCast, exp->GetStartPosition(), exp->GetLength()); return new BoundBadExpression(exp->GetStartPosition(), exp->GetLength()); } else if (castResult == CastResult::DataLoss){ this->_scriptData->Diagnostics->LogWarning(Diagnostics::DiagnosticCode::DataLossOnCast, exp->GetStartPosition(), exp->GetLength()); } else if (castResult == CastResult::UncheckedCast){ this->_scriptData->Diagnostics->LogInfo(Diagnostics::DiagnosticCode::UnvalidatedCast, exp->GetStartPosition(), exp->GetLength()); } return new BoundCastExpression(toCastParameter, destinationType); } BoundExpression *Binder::BindIndexExpression(const IndexExpression *expression, bool setter) { auto indexer = this->BindExpression(expression->GetIndexer()); auto index = this->BindExpression(expression->GetIndex()); auto indexerType = indexer->GetType(); if (!indexerType->CanBeIndexedWith(index->GetType().get())) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::CantIndex, index->GetStartPosition(), index->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } if (indexerType->GetClass() == TypeClass::UserData) { auto stringKey = dynamic_pointer_cast(index->GetType()); auto field = dynamic_pointer_cast(indexerType)->GetField(stringKey->GetHashValue()); if (!setter) { if (!field->HasGetter()) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::UserDataFieldNoGetter, index->GetStartPosition(), index->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } } else { if (!field->HasSetter()) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::UserDataFieldNoSetter, index->GetStartPosition(), index->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } } } auto resultType = indexer->GetType()->GetIndexedType(index->GetType().get()); return new BoundIndexExpression(indexer, index, resultType, expression->GetStartPosition(), expression->GetLength()); } BoundExpression *Binder::BindPeriodIndexExpression(const PeriodIndexExpression *expression, bool setter) { auto indexer = this->BindExpression(expression->GetIndexer()); const auto &identifier = expression->GetIndex(); const auto &indexerType = indexer->GetType(); if (!indexerType->CanBeIndexedWithIdentifier(identifier.GetHash())) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::CantIndex, expression->GetStartPosition(), expression->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } if (indexerType->GetClass() == TypeClass::UserData) { auto field = dynamic_pointer_cast(indexerType)->GetField(identifier.GetHash()); if (!setter) { if (!field->HasGetter()) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::UserDataFieldNoGetter, expression->GetStartPosition(), expression->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } } else { if (!field->HasSetter()) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::UserDataFieldNoSetter, expression->GetStartPosition(), expression->GetLength()); return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength()); } } } auto resultType = indexer->GetType()->GetIndexedType(identifier.GetHash()); return new BoundPeriodIndexExpression(indexer, identifier, resultType, expression->GetStartPosition(), expression->GetLength()); } BoundExpression *Binder::BindNumericalTableExpression(const ParsedNumericalTableExpression *expression) { auto expressions = expression->GetExpressions(); auto boundExpressions = vector(expressions->size()); shared_ptr valueType = nullptr; if (!boundExpressions.empty()) { boundExpressions[0] = this->BindExpression(expressions->at(0)); valueType = boundExpressions[0]->GetType(); for (size_t i = 1; i < expressions->size(); i++) { boundExpressions[i] = this->BindExpression(expressions->at(i)); if (boundExpressions[i]->GetType()->operator!=(valueType)) { this->_scriptData->Diagnostics->LogError(Diagnostics::DiagnosticCode::InvalidTableValueType, boundExpressions[i]->GetStartPosition(), boundExpressions[i]->GetLength()); } } } if (valueType == nullptr) { valueType = ScriptType::NilType; } auto tableType = std::make_shared(valueType); return new BoundNumericalTableExpression(boundExpressions, tableType, expression->GetStartPosition(), expression->GetLength()); } BoundExpression *Binder::BindTableExpression(const ParsedTableExpression *expression) { auto tableScope = new map(); auto innerScope = new BoundScope(tableScope, nullptr); auto currentScope = this->_scope; this->_scope = innerScope; auto block = dynamic_cast(this->BindBlockStatement(expression->GetBlock())); this->_scope = currentScope; auto tableType = std::make_shared(tableScope, innerScope->GetLocalVariableCount()); delete innerScope; return new BoundTableExpression(block, tableType, expression->GetStartPosition(), expression->GetLength()); } }