diff --git a/Upsilon/Binder/Binder.cs b/Upsilon/Binder/Binder.cs index f778a26..d7587ed 100644 --- a/Upsilon/Binder/Binder.cs +++ b/Upsilon/Binder/Binder.cs @@ -13,8 +13,8 @@ namespace Upsilon.Binder private readonly Diagnostics _diagnostics; public BoundScope Scope { get; private set; } - private Dictionary _unboundFunctions = - new Dictionary(); + private Dictionary _unboundFunctions = + new Dictionary(); public Binder(Diagnostics diagnostics, Dictionary variables) { @@ -37,7 +37,7 @@ namespace Upsilon.Binder Scope = Scope.ParentScope; unboundFunctionStatement.Key.IsBound = true; } - _unboundFunctions = new Dictionary(); + _unboundFunctions = new Dictionary(); return new BoundScript((BoundBlockStatement) bound); } @@ -53,10 +53,10 @@ namespace Upsilon.Binder return BindBlockStatement((BlockStatementSyntax) s); case SyntaxKind.IfStatement: return BindIfStatement((IfStatementSyntax) s); - case SyntaxKind.FunctionStatement: - return BindFunctionStatement((FunctionStatementSyntax) s); case SyntaxKind.ReturnStatement: return BindReturnStatement((ReturnStatementSyntax) s); + case SyntaxKind.FunctionAssignmentStatement: + return BindFunctionAssignmentStatement((FunctionAssignmentStatementSyntax) s); } throw new NotImplementedException(s.Kind.ToString()); @@ -82,6 +82,8 @@ namespace Upsilon.Binder return BindTableExpression((TableExpressionSyntax) e); case SyntaxKind.IndexExpression: return BindIndexExpression((IndexExpressionSyntax) e); + case SyntaxKind.FunctionExpression: + return BindFunctionExpression((FunctionExpressionSyntax) e); case SyntaxKind.BadExpression: break; case SyntaxKind.ScriptUnit: @@ -305,14 +307,44 @@ namespace Upsilon.Binder } } - private BoundStatement BindFunctionStatement(FunctionStatementSyntax e) + private BoundExpression BindFunctionExpression(FunctionExpressionSyntax e) { + var innerScope = new BoundScope(Scope); + var parameters = ImmutableArray.CreateBuilder(); + foreach (var identifierToken in e.Parameters) + { + var vari = new VariableSymbol(identifierToken.Name, Type.Unknown, true); + parameters.Add(vari); + innerScope.SetVariable(vari); + } + + if (parameters.Count == 0) + { + Scope = innerScope; + var block = BindBlockStatement(e.Block); + Scope = Scope.ParentScope; + var func = new BoundFunctionExpression(parameters.ToImmutable(), (BoundBlockStatement) block); + return func; + } + else + { + var unbound = new UnboundFunctionExpression(parameters.ToImmutable(), e.Block); + _unboundFunctions.Add( + new FunctionVariableSymbol(Guid.NewGuid().ToString(), Type.Unknown, true, parameters.ToImmutable()), + unbound); + return unbound; + } + } + + private BoundStatement BindFunctionAssignmentStatement(FunctionAssignmentStatementSyntax e) + { + var func = (BoundFunctionExpression)BindFunctionExpression(e.FunctionExpression); var name = e.Identifier.Name; var isLocal = e.LocalToken != null; var innerScope = new BoundScope(Scope); var parameters = ImmutableArray.CreateBuilder(); - foreach (var identifierToken in e.Parameters) + foreach (var identifierToken in func.Parameters) { var vari = new VariableSymbol(identifierToken.Name, Type.Unknown, true); parameters.Add(vari); @@ -339,26 +371,14 @@ namespace Upsilon.Binder else { _diagnostics.LogCannotConvert(Type.Function, variable.Type, e.Span); - return new BoundExpressionStatement(null); + return new BoundExpressionStatement(new BoundLiteralExpression(new LuaNull())); } } } - if (parameters.Count == 0) - { - Scope = innerScope; - var block = BindBlockStatement(e.Block); - Scope = Scope.ParentScope; - ((FunctionVariableSymbol) variable).IsBound = true; - var func = new BoundFunctionStatement(variable, parameters.ToImmutable(), (BoundBlockStatement) block); - return func; - } - else - { - var unbound = new UnboundFunctionStatement(variable, parameters.ToImmutable(), e.Block); - _unboundFunctions.Add((FunctionVariableSymbol) variable, unbound); - return unbound; - } + ((FunctionVariableSymbol) variable).IsBound = true; + + return new BoundFunctionAssignmentStatement(variable, func); } private BoundStatement BindReturnStatement(ReturnStatementSyntax e) diff --git a/Upsilon/Binder/BoundExpressions/BoundFunctionCallExpression.cs b/Upsilon/Binder/BoundExpressions/BoundFunctionCallExpression.cs index 8949f48..b3f9824 100644 --- a/Upsilon/Binder/BoundExpressions/BoundFunctionCallExpression.cs +++ b/Upsilon/Binder/BoundExpressions/BoundFunctionCallExpression.cs @@ -15,6 +15,6 @@ namespace Upsilon.Binder } public override BoundKind Kind => BoundKind.BoundFunctionCallExpression; - public override Type Type => Type.Nil; + public override Type Type => Type.Unknown; } } \ No newline at end of file diff --git a/Upsilon/Binder/BoundKind.cs b/Upsilon/Binder/BoundKind.cs index 0ae0d1c..6865d82 100644 --- a/Upsilon/Binder/BoundKind.cs +++ b/Upsilon/Binder/BoundKind.cs @@ -18,8 +18,9 @@ namespace Upsilon.Binder BoundBlockStatement, BoundIfStatement, BoundElseStatement, - BoundFunctionStatement, + BoundFunctionExpression, BoundPromise, BoundReturnStatement, + BoundFunctionAssignmentStatement } } \ No newline at end of file diff --git a/Upsilon/Binder/BoundStatements/BoundFunctionExpression.cs b/Upsilon/Binder/BoundStatements/BoundFunctionExpression.cs new file mode 100644 index 0000000..6615923 --- /dev/null +++ b/Upsilon/Binder/BoundStatements/BoundFunctionExpression.cs @@ -0,0 +1,34 @@ +using System.Collections.Immutable; +using Upsilon.BaseTypes; + +namespace Upsilon.Binder +{ + public class BoundFunctionExpression : BoundExpression + { + public ImmutableArray Parameters { get; } + public BoundBlockStatement Block { get; set; } + + public BoundFunctionExpression(ImmutableArray parameters, BoundBlockStatement block) + { + Parameters = parameters; + Block = block; + } + + public override BoundKind Kind => BoundKind.BoundFunctionExpression; + public override Type Type => Type.Function; + } + + public class BoundFunctionAssignmentStatement : BoundStatement + { + public VariableSymbol Variable { get; } + public BoundFunctionExpression Func { get; } + + public BoundFunctionAssignmentStatement(VariableSymbol variable, BoundFunctionExpression func) + { + Variable = variable; + Func = func; + } + + public override BoundKind Kind => BoundKind.BoundFunctionAssignmentStatement; + } +} \ No newline at end of file diff --git a/Upsilon/Binder/BoundStatements/BoundFunctionStatement.cs b/Upsilon/Binder/BoundStatements/BoundFunctionStatement.cs deleted file mode 100644 index 27529ea..0000000 --- a/Upsilon/Binder/BoundStatements/BoundFunctionStatement.cs +++ /dev/null @@ -1,21 +0,0 @@ -using System.Collections.Immutable; - -namespace Upsilon.Binder -{ - public class BoundFunctionStatement : BoundStatement - { - public VariableSymbol Identifier { get; } - public ImmutableArray Parameters { get; } - public BoundBlockStatement Block { get; set; } - - public BoundFunctionStatement(VariableSymbol identifier, ImmutableArray parameters, - BoundBlockStatement block) - { - Identifier = identifier; - Parameters = parameters; - Block = block; - } - - public override BoundKind Kind => BoundKind.BoundFunctionStatement; - } -} \ No newline at end of file diff --git a/Upsilon/Binder/BoundStatements/UnboundFunctionStatement.cs b/Upsilon/Binder/BoundStatements/UnboundFunctionExpression.cs similarity index 52% rename from Upsilon/Binder/BoundStatements/UnboundFunctionStatement.cs rename to Upsilon/Binder/BoundStatements/UnboundFunctionExpression.cs index 46bca44..89ee80b 100644 --- a/Upsilon/Binder/BoundStatements/UnboundFunctionStatement.cs +++ b/Upsilon/Binder/BoundStatements/UnboundFunctionExpression.cs @@ -3,10 +3,10 @@ using Upsilon.Parser; namespace Upsilon.Binder { - public class UnboundFunctionStatement : BoundFunctionStatement + public class UnboundFunctionExpression : BoundFunctionExpression { - public UnboundFunctionStatement(VariableSymbol identifier, ImmutableArray parameters, - BlockStatementSyntax unboundBlock) : base(identifier, parameters, null) + public UnboundFunctionExpression(ImmutableArray parameters, + BlockStatementSyntax unboundBlock) : base(parameters, null) { UnboundBlock = unboundBlock; } diff --git a/Upsilon/Evaluator/Evaluator.cs b/Upsilon/Evaluator/Evaluator.cs index 8a6a5dd..284e3d3 100644 --- a/Upsilon/Evaluator/Evaluator.cs +++ b/Upsilon/Evaluator/Evaluator.cs @@ -77,7 +77,9 @@ namespace Upsilon.Evaluator case BoundKind.BoundUnaryExpression: case BoundKind.VariableExpression: case BoundKind.BoundFunctionCallExpression: + case BoundKind.BoundFunctionExpression: case BoundKind.BoundTableExpression: + case BoundKind.BoundIndexExpression: _lastValue = EvaluateExpression((BoundExpression) b); break; case BoundKind.BoundAssignmentStatement: @@ -85,8 +87,9 @@ namespace Upsilon.Evaluator case BoundKind.BoundBlockStatement: case BoundKind.BoundIfStatement: case BoundKind.BoundElseStatement: - case BoundKind.BoundFunctionStatement: + case BoundKind.BoundFunctionAssignmentStatement: case BoundKind.BoundPromise: + case BoundKind.BoundReturnStatement: EvaluateStatement((BoundStatement) b); break; default: @@ -110,15 +113,12 @@ namespace Upsilon.Evaluator case BoundKind.BoundIfStatement: EvaluateBoundIfStatement((BoundIfStatement) e); break; - case BoundKind.BoundFunctionStatement: - EvaluateBoundFunctionStatement((BoundFunctionStatement) e); - break; - case BoundKind.BoundPromise: - EvaluateUnboundFunctionStatement((UnboundFunctionStatement) e); - break; case BoundKind.BoundReturnStatement: EvaluateReturnStatement((BoundReturnStatement) e); break; + case BoundKind.BoundFunctionAssignmentStatement: + EvaluateBoundFunctionAssigmentStatement((BoundFunctionAssignmentStatement) e); + break; default: EvaluateExpressionStatement((BoundExpressionStatement) e); break; @@ -149,6 +149,11 @@ namespace Upsilon.Evaluator return EvaluateTableExpression((BoundTableExpression) e); case BoundKind.BoundIndexExpression: return EvaluateIndexExpression((BoundIndexExpression) e); + case BoundKind.BoundFunctionExpression: + return EvaluateBoundFunctionStatement((BoundFunctionExpression) e); + case BoundKind.BoundPromise: + return EvaluateUnboundFunctionStatement((UnboundFunctionExpression) e); + break; default: throw new NotImplementedException(); } @@ -254,26 +259,25 @@ namespace Upsilon.Evaluator _lastValue = innerEvaluator._lastValue; } - private void EvaluateBoundFunctionStatement(BoundFunctionStatement boundFunctionStatement) + private void EvaluateBoundFunctionAssigmentStatement(BoundFunctionAssignmentStatement e) { - var func = new LuaFunction(boundFunctionStatement.Parameters, boundFunctionStatement.Block); - if (boundFunctionStatement.Identifier.Local) - Scope.Set(boundFunctionStatement.Identifier, func); + var func = EvaluateBoundFunctionStatement(e.Func); + if (e.Variable.Local) + Scope.Set(e.Variable, func); else - { - Scope.SetGlobal(boundFunctionStatement.Identifier, func); - _lastValue = func; - } + Scope.SetGlobal(e.Variable, func); } - private void EvaluateUnboundFunctionStatement(UnboundFunctionStatement unboundFunctionStatement) + private LuaType EvaluateBoundFunctionStatement(BoundFunctionExpression boundFunctionExpression) { - var func = new LuaFunction(unboundFunctionStatement.Parameters, unboundFunctionStatement.Block); - if (unboundFunctionStatement.Identifier.Local) - Scope.Set(unboundFunctionStatement.Identifier, func); - else - Scope.SetGlobal(unboundFunctionStatement.Identifier, func); + var func = new LuaFunction(boundFunctionExpression.Parameters, boundFunctionExpression.Block); + return func; + } + private LuaType EvaluateUnboundFunctionStatement(UnboundFunctionExpression unboundFunctionExpression) + { + var func = new LuaFunction(unboundFunctionExpression.Parameters, unboundFunctionExpression.Block); + return func; } private LuaType EvaluateBoundFunctionCallExpression(BoundFunctionCallExpression boundFunctionCallExpression) @@ -281,7 +285,7 @@ namespace Upsilon.Evaluator var variable = EvaluateExpression(boundFunctionCallExpression.Identifier); if (!(variable is LuaFunction function)) { - throw new Exception("Variable is not a function."); + throw new Exception($"Variable is not a function."); } var innerEvaluator = new Evaluator(_diagnostics, Scope); @@ -316,12 +320,19 @@ namespace Upsilon.Evaluator var value = EvaluateExpression(assignment.BoundExpression); dic.Add(key, value); } - else if (boundStatement.Kind == BoundKind.BoundFunctionStatement) + else if (boundStatement.Kind == BoundKind.BoundFunctionExpression) { - var function = (BoundFunctionStatement) boundStatement; - var key = function.Identifier; + var expressionStatement = (BoundExpressionStatement)boundStatement; + var function = (BoundFunctionExpression) expressionStatement.Expression; var func = new LuaFunction(function.Parameters, function.Block); - dic.Add(key, func); + dic.Add(new VariableSymbol(currentPos.ToString(), func.Type, false), func); + } + else if (boundStatement.Kind == BoundKind.BoundFunctionAssignmentStatement) + { + var assignment = (BoundFunctionAssignmentStatement)boundStatement; + var key = assignment.Variable; + var value = EvaluateExpression(assignment.Func); + dic.Add(key, value); } else { diff --git a/Upsilon/Parser/StatementSyntax/FunctionStatementSyntax.cs b/Upsilon/Parser/ExpressionSyntax/FunctionExpressionSyntax.cs similarity index 71% rename from Upsilon/Parser/StatementSyntax/FunctionStatementSyntax.cs rename to Upsilon/Parser/ExpressionSyntax/FunctionExpressionSyntax.cs index 1fdd965..5a4ca0f 100644 --- a/Upsilon/Parser/StatementSyntax/FunctionStatementSyntax.cs +++ b/Upsilon/Parser/ExpressionSyntax/FunctionExpressionSyntax.cs @@ -3,24 +3,20 @@ using System.Collections.Immutable; namespace Upsilon.Parser { - public class FunctionStatementSyntax : StatementSyntax + public class FunctionExpressionSyntax : ExpressionSyntax { - public SyntaxToken LocalToken { get; } public SyntaxToken FunctionToken { get; } - public IdentifierToken Identifier { get; } public SyntaxToken OpenParenthesis { get; } public ImmutableArray Parameters { get; } public SyntaxToken CloseParenthesis { get; } public BlockStatementSyntax Block { get; } public SyntaxToken EndToken { get; } - public FunctionStatementSyntax(SyntaxToken localToken, SyntaxToken functionToken, IdentifierToken identifier, + public FunctionExpressionSyntax(SyntaxToken functionToken, SyntaxToken openParenthesis, ImmutableArray parameters, SyntaxToken closeParenthesis, BlockStatementSyntax block, SyntaxToken endToken) { - LocalToken = localToken; FunctionToken = functionToken; - Identifier = identifier; OpenParenthesis = openParenthesis; Parameters = parameters; CloseParenthesis = closeParenthesis; @@ -28,12 +24,10 @@ namespace Upsilon.Parser EndToken = endToken; } - public override SyntaxKind Kind => SyntaxKind.FunctionStatement; + public override SyntaxKind Kind => SyntaxKind.FunctionExpression; public override IEnumerable ChildNodes() { - yield return LocalToken; yield return FunctionToken; - yield return Identifier; yield return OpenParenthesis; foreach (var identifierToken in Parameters) { diff --git a/Upsilon/Parser/Parser.cs b/Upsilon/Parser/Parser.cs index c57d679..7d080e6 100644 --- a/Upsilon/Parser/Parser.cs +++ b/Upsilon/Parser/Parser.cs @@ -70,19 +70,19 @@ namespace Upsilon.Parser { return ParseIfStatement(SyntaxKind.IfKeyword); } - - if (Current.Kind == SyntaxKind.FunctionKeyword) - { - return ParseFunctionStatement(); - } - if (Current.Kind == SyntaxKind.LocalKeyword && Next.Kind == SyntaxKind.FunctionKeyword) - { - return ParseFunctionStatement(); - } if (Current.Kind == SyntaxKind.ReturnKeyword) { return ParseReturnStatement(); } + if (Current.Kind == SyntaxKind.FunctionKeyword && Next.Kind != SyntaxKind.OpenParenthesis) + { + return ParseFunctionAssignmentStatement(); + } + if (Current.Kind == SyntaxKind.LocalKeyword && Next.Kind == SyntaxKind.FunctionKeyword) + { + return ParseFunctionAssignmentStatement(); + } + return ParseExpressionStatement(); } @@ -127,15 +127,9 @@ namespace Upsilon.Parser } } - private StatementSyntax ParseFunctionStatement() + private ExpressionSyntax ParseFunctionExpression() { - SyntaxToken localToken = null; - if (Current.Kind == SyntaxKind.LocalKeyword) - { - localToken = NextToken(); - } var functionToken = MatchToken(SyntaxKind.FunctionKeyword); - var identifier = (IdentifierToken)MatchToken(SyntaxKind.Identifier); var openParenthesis = MatchToken(SyntaxKind.OpenParenthesis); var variableBuilder = ImmutableArray.CreateBuilder(); while (Current.Kind != SyntaxKind.CloseParenthesis) @@ -146,12 +140,38 @@ namespace Upsilon.Parser NextToken(); } var closeParenthesis = MatchToken(SyntaxKind.CloseParenthesis); - var block = ParseBlockStatement(new[] {SyntaxKind.EndKeyword}); - var endToken = MatchToken(SyntaxKind.EndKeyword); - return new FunctionStatementSyntax(localToken, functionToken, identifier, openParenthesis, + var block = ParseBlockStatement(new[] {SyntaxKind.EndKeyword}); + var endToken = MatchToken(SyntaxKind.EndKeyword); + return new FunctionExpressionSyntax(functionToken, openParenthesis, variableBuilder.ToImmutable(), closeParenthesis, (BlockStatementSyntax) block, endToken); } + private StatementSyntax ParseFunctionAssignmentStatement() + { + SyntaxToken localToken = null; + if (Current.Kind == SyntaxKind.LocalKeyword) + { + localToken = NextToken(); + } + var functionToken = MatchToken(SyntaxKind.FunctionKeyword); + var identifier = (IdentifierToken)MatchToken(SyntaxKind.Identifier); + var openParenthesis = MatchToken(SyntaxKind.OpenParenthesis); + var variableBuilder = ImmutableArray.CreateBuilder(); + while (Current.Kind != SyntaxKind.CloseParenthesis) + { + var variableIdentifier = (IdentifierToken)MatchToken(SyntaxKind.Identifier); + variableBuilder.Add(variableIdentifier); + if (Current.Kind == SyntaxKind.Comma) + NextToken(); + } + var closeParenthesis = MatchToken(SyntaxKind.CloseParenthesis); + var block = ParseBlockStatement(new[] {SyntaxKind.EndKeyword}); + var endToken = MatchToken(SyntaxKind.EndKeyword); + var functionExpression = new FunctionExpressionSyntax(functionToken, openParenthesis, + variableBuilder.ToImmutable(), closeParenthesis, (BlockStatementSyntax) block, endToken); + return new FunctionAssignmentStatementSyntax(localToken, identifier, functionExpression); + } + private ExpressionStatementSyntax ParseExpressionStatement() { var expression = ParseExpression(); @@ -167,7 +187,15 @@ namespace Upsilon.Parser private ExpressionSyntax ParseExpression() { - var expression = ParseBinaryExpression(); + ExpressionSyntax expression; + if (Current.Kind == SyntaxKind.FunctionKeyword && Next.Kind == SyntaxKind.OpenParenthesis) + { + expression = ParseFunctionExpression(); + } + else + { + expression = ParseBinaryExpression(); + } while (Current.Kind == SyntaxKind.OpenBracket || Current.Kind == SyntaxKind.OpenParenthesis) { if (Current.Kind == SyntaxKind.OpenBracket) diff --git a/Upsilon/Parser/StatementSyntax/FunctionAssignmentStatementSyntax.cs b/Upsilon/Parser/StatementSyntax/FunctionAssignmentStatementSyntax.cs new file mode 100644 index 0000000..8784a14 --- /dev/null +++ b/Upsilon/Parser/StatementSyntax/FunctionAssignmentStatementSyntax.cs @@ -0,0 +1,27 @@ +using System.Collections.Generic; + +namespace Upsilon.Parser +{ + public class FunctionAssignmentStatementSyntax : StatementSyntax + { + public FunctionAssignmentStatementSyntax(SyntaxToken localToken, IdentifierToken identifier, FunctionExpressionSyntax functionExpression) + { + LocalToken = localToken; + Identifier = identifier; + FunctionExpression = functionExpression; + } + + public SyntaxToken LocalToken { get; } + public IdentifierToken Identifier { get; } + public FunctionExpressionSyntax FunctionExpression { get; } + + public override SyntaxKind Kind => SyntaxKind.FunctionAssignmentStatement; + public override IEnumerable ChildNodes() + { + if (LocalToken != null) + yield return LocalToken; + yield return Identifier; + yield return FunctionExpression; + } + } +} \ No newline at end of file diff --git a/Upsilon/Parser/SyntaxKind.cs b/Upsilon/Parser/SyntaxKind.cs index 48ed33d..0dc78c7 100644 --- a/Upsilon/Parser/SyntaxKind.cs +++ b/Upsilon/Parser/SyntaxKind.cs @@ -65,7 +65,8 @@ namespace Upsilon.Parser IfStatement, ElseIfStatement, ElseStatement, - FunctionStatement, + FunctionExpression, ReturnStatement, + FunctionAssignmentStatement } } \ No newline at end of file diff --git a/UpsilonTests/TableTests.cs b/UpsilonTests/TableTests.cs index 8a43cad..cddcc8d 100644 --- a/UpsilonTests/TableTests.cs +++ b/UpsilonTests/TableTests.cs @@ -1,3 +1,4 @@ +using System; using Upsilon.Evaluator; using Xunit; @@ -104,9 +105,9 @@ return table[""test""]() const string input = @" table = { function func() - return function func() + return function() return { - function func() + function() return 100 end }