#include "Binder.hpp"

#include <stdexcept>
#include <type_traits>
#include <vector>

// Binds a parsed script into its bound representation, one top-level statement at a time.
//
// A temporary Binder is created around `script` (whose Diagnostics receive binding errors)
// and `scriptScope`.
// NOTE(review): ~Binder() deletes `_scope`, so `scriptScope` is deleted when this function
// returns. The deepest-scope value passed to BoundScriptStatement is read before the local
// binder is destroyed, but callers must not use `scriptScope` afterwards — confirm this
// ownership transfer is intended.
BoundScriptStatement *Binder::Bind(Script* script, ParsedScriptStatement *s, BoundScope* scriptScope) {
    auto binder = Binder();
    binder._scriptData = script;
    binder._scope = scriptScope;

    const auto& statements = s->GetStatements();
    std::vector<BoundStatement*> boundStatements;
    boundStatements.reserve(statements.size());
    for (auto statement : statements){
        boundStatements.push_back(binder.BindStatement(statement));
    }
    return new BoundScriptStatement(boundStatements, scriptScope->GetDeepestScope());
}

// Deletes the scope this binder was handed (see Bind()).
Binder::~Binder() {
    delete _scope;
}

// Dispatches a parsed statement to the matching Bind*Statement method.
// Throws std::logic_error for statement kinds that must never reach this point.
BoundStatement* Binder::BindStatement(ParsedStatement* statement){
    switch (statement->GetKind()) {
        case ParsedStatementKind::Script:
            // Script nodes are only expected at the root, which Bind() unwraps itself.
            // A bare `throw;` here would call std::terminate, so raise a real exception instead.
            throw std::logic_error("Binder::BindStatement: unexpected nested script statement");
        case ParsedStatementKind::Block:
            return BindBlockStatement(statement);
        case ParsedStatementKind::Expression:
            return BindExpressionStatement(statement);
        case ParsedStatementKind::Assignment:
            return BindAssignmentStatement(statement);
        case ParsedStatementKind::FunctionDeclaration:
            return BindFunctionDeclarationStatement(statement);
        case ParsedStatementKind::Bad:
            return new BoundBadStatement();
    }
    // Falling off the end of a non-void function is undefined behaviour; fail loudly instead.
    throw std::logic_error("Binder::BindStatement: unhandled statement kind");
}

// Binds every child of a block statement inside a fresh inner scope, then returns to the
// enclosing scope.
BoundStatement *Binder::BindBlockStatement(ParsedStatement *statement) {
    const auto& statements = static_cast<ParsedBlockStatement*>(statement)->GetStatements();
    std::vector<BoundStatement*> boundStatements;
    boundStatements.reserve(statements.size());

    _scope->GoInnerScope();
    for (auto child : statements){
        boundStatements.push_back(BindStatement(child));
    }
    _scope->GoOuterScope();

    return new BoundBlockStatement(boundStatements);
}

// Binds a statement that consists of a single expression.
BoundStatement *Binder::BindExpressionStatement(ParsedStatement *statement) {
    auto expression = static_cast<ParsedExpressionStatement*>(statement)->GetExpression();
    return new BoundExpressionStatement(BindExpression(expression));
}

// Binds an assignment. Local assignments go through CreateExplicitLocal, all others through
// AssignVariable; both receive the type of the bound right-hand side.
// On failure a CantAssignVariable diagnostic is logged and a BoundBadStatement is returned.
BoundStatement* Binder::BindAssignmentStatement(ParsedStatement *statement){
    auto assignmentStatement = static_cast<ParsedAssignmentStatement*>(statement);
    auto boundExpression = BindExpression(assignmentStatement->GetExpression());
    auto hash = assignmentStatement->GetIdentifier().GetHash();

    VariableAssignment assignment = assignmentStatement->IsLocal()
            ? _scope->CreateExplicitLocal(hash, *boundExpression->GetType())
            : _scope->AssignVariable(hash, *boundExpression->GetType());

    if (assignment.GetResult() != VariableAssignmentResult::Ok){
        // NOTE(review): boundExpression is not released on this path; its ownership is not
        // visible from this file.
        _scriptData->Diagnostics->LogError(DiagnosticCode::CantAssignVariable,
                                           statement->GetStartPosition(), statement->GetLength());
        return new BoundBadStatement();
    }
    return new BoundAssignmentStatement(assignment.GetKey(), boundExpression);
}

// Maps a type name written in source ("number", "bool", "string") to a newly allocated
// ScriptType. Any other name currently yields an Error type; no diagnostic is logged here.
ScriptType* ParseTypeIdentifier(HashedString s){
    switch (s.GetHash()){
        case HashedString::ConstHash("number"):
            return new NumericScriptType(false, false);
        case HashedString::ConstHash("bool"):
            return new ScriptType(TypeClass::Bool);
        case HashedString::ConstHash("string"):
            return new ScriptType(TypeClass::String);
        default:
            return new ScriptType(TypeClass::Error); // todo: change to userdata
    }
}

// Binds a function declaration.
//
// Parameters are declared in their own inner scope (whose id is stored on the function type);
// BindBlockStatement then opens a further inner scope for the body. The return type is
// currently always Nil. If the function name cannot be assigned, a CantAssignVariable
// diagnostic is logged (matching BindAssignmentStatement) and a BoundBadStatement is returned.
BoundStatement *Binder::BindFunctionDeclarationStatement(ParsedStatement *statement) {
    auto functionStatement = static_cast<ParsedFunctionDeclarationStatement*>(statement);
    const auto& parameters = functionStatement->GetParameters();
    // Key element type follows whatever HashedString::GetHash() returns.
    using ParameterKey = std::decay_t<decltype(parameters[0]->GetIdentifier().GetHash())>;

    std::vector<ScriptType*> parameterTypes;
    std::vector<ParameterKey> parameterKeys;
    parameterTypes.reserve(parameters.size());
    parameterKeys.reserve(parameters.size());

    _scope->GoInnerScope();
    auto scopeId = _scope->GetCurrentScope();
    for (auto parameter : parameters){
        auto parsedType = ParseTypeIdentifier(parameter->GetType());
        auto key = parameter->GetIdentifier().GetHash();
        parameterTypes.push_back(parsedType);
        parameterKeys.push_back(key);
        _scope->CreateExplicitLocal(key, *parsedType);
    }
    auto boundBlock = BindBlockStatement(functionStatement->GetBlock());
    _scope->GoOuterScope();

    auto identifier = functionStatement->GetIdentifier();
    auto returnType = new ScriptType(TypeClass::Nil);
    auto type = new FunctionScriptType(returnType, parameterTypes, parameterKeys, scopeId);
    auto assignment = _scope->AssignVariable(identifier.GetHash(), *type);
    if (assignment.GetResult() == VariableAssignmentResult::Ok){
        return new BoundFunctionDeclarationStatement(type, assignment.GetKey(),
                                                     static_cast<BoundBlockStatement*>(boundBlock));
    }
    // NOTE(review): `type` and `boundBlock` are not released on this path; their ownership is
    // not visible from this file.
    _scriptData->Diagnostics->LogError(DiagnosticCode::CantAssignVariable,
                                       statement->GetStartPosition(), statement->GetLength());
    return new BoundBadStatement();
}

// Dispatches a parsed expression to the matching binder. Literals are bound directly;
// parenthesized expressions bind to their inner expression.
// Throws std::logic_error for an unhandled expression kind.
BoundExpression* Binder::BindExpression(ParsedExpression* expression){
    switch (expression->GetKind()){
        case ParsedExpressionKind::LiteralInteger:
            return new BoundLiteralIntegerExpression(static_cast<LiteralIntegerExpression*>(expression)->GetValue(),
                                                     expression->GetStartPosition(), expression->GetLength());
        case ParsedExpressionKind::LiteralFloat:
            return new BoundLiteralFloatExpression(static_cast<LiteralFloatExpression*>(expression)->GetValue(),
                                                   expression->GetStartPosition(), expression->GetLength());
        case ParsedExpressionKind::LiteralString:
            return new BoundLiteralStringExpression(static_cast<LiteralStringExpression*>(expression)->GetValue(),
                                                    expression->GetStartPosition(), expression->GetLength());
        case ParsedExpressionKind::LiteralBool:
            return new BoundLiteralBoolExpression(static_cast<LiteralBoolExpression*>(expression)->GetValue(),
                                                  expression->GetStartPosition(), expression->GetLength());
        case ParsedExpressionKind::Variable:
            return BindVariableExpression(static_cast<VariableExpression*>(expression));
        case ParsedExpressionKind::Binary:
            return BindBinaryOperator(static_cast<BinaryExpression*>(expression));
        case ParsedExpressionKind::Unary:
            return BindUnaryOperator(static_cast<UnaryExpression*>(expression));
        case ParsedExpressionKind::Parenthesized:
            return BindExpression(static_cast<ParenthesizedExpression*>(expression)->GetInnerExpression());
        case ParsedExpressionKind::Bad:
            return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength());
    }
    // Falling off the end of a non-void function is undefined behaviour; fail loudly instead.
    throw std::logic_error("Binder::BindExpression: unhandled expression kind");
}

// Resolves a variable reference against the current scope chain. If Exists() reports -1 the
// variable is unknown: a VariableNotFound diagnostic is logged and a BoundBadExpression returned.
BoundExpression* Binder::BindVariableExpression(VariableExpression* expression){
    auto key = expression->GetValue();
    auto scope = _scope->Exists(key.GetHash());
    if (scope == -1){
        _scriptData->Diagnostics->LogError(DiagnosticCode::VariableNotFound,
                                           expression->GetStartPosition(), expression->GetLength());
        return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength());
    }
    auto var = _scope->GetVariable(scope, key.GetHash());
    auto type = var->GetType();
    return new BoundVariableExpression(scope, key.GetHash(), type,
                                       expression->GetStartPosition(), expression->GetLength());
}

namespace {
    // Binds an arithmetic operator over two numeric operands.
    // Returns nullptr when either operand is not a number, so the caller can try another
    // interpretation or report the error.
    BoundExpression* BindNumericBinaryOperation(BoundExpression* left, BoundExpression* right,
                                                BoundBinaryOperation operation, BinaryExpression* expression){
        auto leftType = left->GetType();
        auto rightType = right->GetType();
        if (leftType->GetClass() != TypeClass::Number || rightType->GetClass() != TypeClass::Number){
            return nullptr;
        }
        auto leftNumeric = static_cast<NumericScriptType*>(leftType);
        auto rightNumeric = static_cast<NumericScriptType*>(rightType);
        // Float-ness is only propagated when both operands are float-aware; otherwise the
        // result is marked as not float-aware.
        auto resultType = (leftNumeric->IsAwareOfFloat() && rightNumeric->IsAwareOfFloat())
                ? new NumericScriptType(true, leftNumeric->IsFloat() || rightNumeric->IsFloat())
                : new NumericScriptType(false, false);
        return new BoundBinaryExpression(left, right, operation, resultType,
                                         expression->GetStartPosition(), expression->GetLength());
    }
}

// Binds a binary operator.
//  - + - * / : both operands must be numbers (see BindNumericBinaryOperation).
//  - +       : additionally, a string left operand means concatenation, whatever the right type.
//  - == !=   : accepted for any operand types, producing a bool.
//  - and/or  : both operands must be bools.
// Anything else logs NoBinaryOperationFound and yields a BoundBadExpression.
BoundExpression* Binder::BindBinaryOperator(BinaryExpression* expression){
    auto boundLeft = BindExpression(expression->GetLeft());
    auto boundRight = BindExpression(expression->GetRight());
    auto boundLeftType = boundLeft->GetType();
    auto boundRightType = boundRight->GetType();

    switch (expression->GetOperatorKind()){
        case BinaryOperatorKind::Addition:
            if (auto numeric = BindNumericBinaryOperation(boundLeft, boundRight, BoundBinaryOperation::Addition, expression)){
                return numeric;
            }
            if (boundLeftType->GetClass() == TypeClass::String){
                return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Concatenation,
                                                 new ScriptType(TypeClass::String),
                                                 expression->GetStartPosition(), expression->GetLength());
            }
            break;
        case BinaryOperatorKind::Subtraction:
            if (auto numeric = BindNumericBinaryOperation(boundLeft, boundRight, BoundBinaryOperation::Subtraction, expression)){
                return numeric;
            }
            break;
        case BinaryOperatorKind::Multiplication:
            if (auto numeric = BindNumericBinaryOperation(boundLeft, boundRight, BoundBinaryOperation::Multiplication, expression)){
                return numeric;
            }
            break;
        case BinaryOperatorKind::Division:
            if (auto numeric = BindNumericBinaryOperation(boundLeft, boundRight, BoundBinaryOperation::Division, expression)){
                return numeric;
            }
            break;
        case BinaryOperatorKind::Equality:
            return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Equality,
                                             new ScriptType(TypeClass::Bool),
                                             expression->GetStartPosition(), expression->GetLength());
        case BinaryOperatorKind::Inequality:
            return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::Inequality,
                                             new ScriptType(TypeClass::Bool),
                                             expression->GetStartPosition(), expression->GetLength());
        case BinaryOperatorKind::LogicalAnd:
            if (boundLeftType->GetClass() == TypeClass::Bool && boundRightType->GetClass() == TypeClass::Bool){
                return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::LogicalAnd,
                                                 new ScriptType(TypeClass::Bool),
                                                 expression->GetStartPosition(), expression->GetLength());
            }
            break;
        case BinaryOperatorKind::LogicalOr:
            if (boundLeftType->GetClass() == TypeClass::Bool && boundRightType->GetClass() == TypeClass::Bool){
                return new BoundBinaryExpression(boundLeft, boundRight, BoundBinaryOperation::LogicalOr,
                                                 new ScriptType(TypeClass::Bool),
                                                 expression->GetStartPosition(), expression->GetLength());
            }
            break;
    }
    _scriptData->Diagnostics->LogError(DiagnosticCode::NoBinaryOperationFound,
                                       expression->GetStartPosition(), expression->GetLength());
    return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength());
}

// Binds a unary operator.
//  - identity (+x) on a number returns the bound operand itself;
//  - negation (-x) on a number keeps the operand's float-awareness and float-ness;
//  - logical negation requires a bool.
// Anything else logs NoUnaryOperationFound and yields a BoundBadExpression.
BoundExpression* Binder::BindUnaryOperator(UnaryExpression* expression){
    auto operand = BindExpression(expression->GetOperand());
    auto operandType = operand->GetType();
    switch (expression->GetOperatorKind()){
        case UnaryOperatorKind::Identity:
            if (operandType->GetClass() == TypeClass::Number){
                // Identity won't change anything during evaluation, so just return the inner operand.
                return operand;
            }
            break;
        case UnaryOperatorKind::Negation:
            if (operandType->GetClass() == TypeClass::Number){
                auto innerType = static_cast<NumericScriptType*>(operandType);
                return new BoundUnaryExpression(operand, BoundUnaryOperation::Negation,
                                                new NumericScriptType(innerType->IsAwareOfFloat(), innerType->IsFloat()),
                                                expression->GetStartPosition(), expression->GetLength());
            }
            break;
        case UnaryOperatorKind::LogicalNegation:
            if (operandType->GetClass() == TypeClass::Bool){
                return new BoundUnaryExpression(operand, BoundUnaryOperation::LogicalNegation,
                                                new ScriptType(TypeClass::Bool),
                                                expression->GetStartPosition(), expression->GetLength());
            }
            break;
        default:
            break;
    }
    _scriptData->Diagnostics->LogError(DiagnosticCode::NoUnaryOperationFound,
                                       expression->GetStartPosition(), expression->GetLength());
    return new BoundBadExpression(expression->GetStartPosition(), expression->GetLength());
}