diff --git a/Upsilon/BaseTypes/UserData/UserData.cs b/Upsilon/BaseTypes/UserData/UserData.cs index 112b379..e2300ec 100644 --- a/Upsilon/BaseTypes/UserData/UserData.cs +++ b/Upsilon/BaseTypes/UserData/UserData.cs @@ -43,5 +43,10 @@ namespace Upsilon.BaseTypes.UserData diagnostics.LogError($"Cannot find member '{s}' on type '{Value.GetType()}'", span); } } + + public (LuaType Type, bool Failed) BinaryOperator(LuaType par1, OperatorType op, LuaType par2) + { + return _typeInfo.BinaryOperator(Value, par1, op, par2); + } } } \ No newline at end of file diff --git a/Upsilon/BaseTypes/UserData/UserDataType.cs b/Upsilon/BaseTypes/UserData/UserDataType.cs index ea5fd2e..02e25cf 100644 --- a/Upsilon/BaseTypes/UserData/UserDataType.cs +++ b/Upsilon/BaseTypes/UserData/UserDataType.cs @@ -24,14 +24,14 @@ namespace Upsilon.BaseTypes.UserData Methods.Add(commonName, new UserDataMethod(methodInfo)); } } - _operatorHandler = new UserDataTypeOperators(type); + OperatorHandler = new UserDataTypeOperators(type); } private System.Type Type { get; } private Dictionary Variables { get; } private Dictionary Properties { get; } private Dictionary Methods { get; } - private UserDataTypeOperators _operatorHandler { get; } + private UserDataTypeOperators OperatorHandler { get; } public (LuaType Type, bool Failed) Get(object value, string member) { @@ -72,9 +72,15 @@ namespace Upsilon.BaseTypes.UserData return true; } - public (LuaType Type, bool Failed) BinaryOperator(object value, OperatorType op, object value2) + public (LuaType Type, bool Failed) BinaryOperator(object value, LuaType par1, OperatorType op, LuaType par2) { - return (null, true); + var method = OperatorHandler.GetBinaryOperator(op, par1.GetCSharpType(), par2.GetCSharpType()); + if (method == null) + { + return (new LuaNull(), true); + } + + return (method.Invoke(value, new[] {par1.ToCSharpObject(), par2.ToCSharpObject()}).ToLuaType(), false); } } } \ No newline at end of file diff --git a/Upsilon/BaseTypes/UserData/UserDataTypeOperators.cs b/Upsilon/BaseTypes/UserData/UserDataTypeOperators.cs index f37aace..26b5251 100644 --- a/Upsilon/BaseTypes/UserData/UserDataTypeOperators.cs +++ b/Upsilon/BaseTypes/UserData/UserDataTypeOperators.cs @@ -1,24 +1,118 @@ using System; +using System.Collections.Generic; +using System.Linq; using System.Reflection; namespace Upsilon.BaseTypes.UserData { public enum OperatorType { + UnaryPlus, + UnaryNegation, + LogicalNot, Addition, + Subtraction, + Multiplication, + Division, } public class UserDataTypeOperators { - private class OperatorKeyData + private class OperatorMethod { + public MethodInfo Info { get; } + public int[] ParameterTypeHashes { get; } + public OperatorMethod(MethodInfo info) + { + Info = info; + ParameterTypeHashes = info.GetParameters().Select(x => x.ParameterType.GetHashCode()).ToArray(); + } } - public UserDataTypeOperators(System.Type t) + private readonly Dictionary> _operatorMethods; + + public UserDataTypeOperators(IReflect t) { - var additionMethod = t.GetMethod("op_Addition", BindingFlags.Static | BindingFlags.Public); - Console.WriteLine(additionMethod); + var staticMethod = t.GetMethods(BindingFlags.Static | BindingFlags.Public); + _operatorMethods = new Dictionary>(); + foreach (var methodInfo in staticMethod) + { + switch (methodInfo.Name) + { + // Unary + case "op_UnaryPlus": + LoadMethod(OperatorType.UnaryPlus, methodInfo); + break; + case "op_UnaryNegation": + LoadMethod(OperatorType.UnaryNegation, methodInfo); + break; + case "op_LogicalNot": + LoadMethod(OperatorType.LogicalNot, methodInfo); + break; + + // Binary + case "op_Addition": + LoadMethod(OperatorType.Addition, methodInfo); + break; + case "op_Subtraction": + LoadMethod(OperatorType.Subtraction, methodInfo); + break; + case "op_Multiply": + LoadMethod(OperatorType.Multiplication, methodInfo); + break; + case "op_Division": + LoadMethod(OperatorType.Division, methodInfo); + break; + } + } + } + + private void LoadMethod(OperatorType type, MethodInfo method) + { + if (_operatorMethods.TryGetValue(type, out var ls)) + { + ls.Add(new OperatorMethod(method)); + } + else + { + _operatorMethods.Add(type, new List(){new OperatorMethod(method)}); + } + } + + public MethodInfo GetBinaryOperator(OperatorType op, System.Type t1, System.Type t2) + { + if (!_operatorMethods.TryGetValue(op, out var m)) + { + return null; + } + + foreach (var operatorMethod in m) + { + if (operatorMethod.ParameterTypeHashes[0] == t1.GetHashCode() && + operatorMethod.ParameterTypeHashes[1] == t2.GetHashCode()) + return operatorMethod.Info; + + if (operatorMethod.ParameterTypeHashes[1] == t2.GetHashCode() && + operatorMethod.ParameterTypeHashes[0] == t1.GetHashCode()) + return operatorMethod.Info; + } + return null; + } + + public MethodInfo GetUnaryOperator(OperatorType op, System.Type t1) + { + if (!_operatorMethods.TryGetValue(op, out var m)) + { + return null; + } + + foreach (var operatorMethod in m) + { + if (operatorMethod.ParameterTypeHashes[0] == t1.GetHashCode()) + return operatorMethod.Info; + } + return null; } } } \ No newline at end of file diff --git a/Upsilon/Evaluator/Evaluator.cs b/Upsilon/Evaluator/Evaluator.cs index e55df56..f67c26d 100644 --- a/Upsilon/Evaluator/Evaluator.cs +++ b/Upsilon/Evaluator/Evaluator.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using Upsilon.BaseTypes; using Upsilon.BaseTypes.Number; +using Upsilon.BaseTypes.UserData; using Upsilon.Binder; using Type = Upsilon.BaseTypes.Type; @@ -185,13 +186,53 @@ namespace Upsilon.Evaluator { return ((LuaString) left) + right; } + else if (left.Type == Type.UserData) + { + var ud = (UserData) left; + var (type, failed) = ud.BinaryOperator(left, OperatorType.Addition, right); + if (failed) goto default; + return type; + } goto default; case BoundBinaryOperator.OperatorKind.Subtraction: - return ((Number)left) - ((Number)right); + if (left.Type == Type.Number) + { + return ((Number)left) - ((Number)right); + } + else if (left.Type == Type.UserData) + { + var ud = (UserData) left; + var (type, failed) = ud.BinaryOperator(left, OperatorType.Subtraction, right); + if (failed) goto default; + return type; + } + goto default; case BoundBinaryOperator.OperatorKind.Multiplication: - return ((Number)left) * ((Number)right); + if (left.Type == Type.Number) + { + return ((Number)left) * ((Number)right); + } + else if (left.Type == Type.UserData) + { + var ud = (UserData) left; + var (type, failed) = ud.BinaryOperator(left, OperatorType.Multiplication, right); + if (failed) goto default; + return type; + } + goto default; case BoundBinaryOperator.OperatorKind.Division: - return ((Number)left) / ((Number)right); + if (left.Type == Type.Number) + { + return ((Number)left) / ((Number)right); + } + else if (left.Type == Type.UserData) + { + var ud = (UserData) left; + var (type, failed) = ud.BinaryOperator(left, OperatorType.Division, right); + if (failed) goto default; + return type; + } + goto default; case BoundBinaryOperator.OperatorKind.Equality: return new LuaBoolean(Equals(left, right)); case BoundBinaryOperator.OperatorKind.Inequality: diff --git a/UpsilonTests/UserDataOperatorTests.cs b/UpsilonTests/UserDataOperatorTests.cs index 8afa63a..04f0e2a 100644 --- a/UpsilonTests/UserDataOperatorTests.cs +++ b/UpsilonTests/UserDataOperatorTests.cs @@ -41,6 +41,21 @@ namespace UpsilonTests return new UserDataHelper(a.Value + b); } + public static UserDataHelper operator -(UserDataHelper a, UserDataHelper b) + { + return new UserDataHelper(a.Value - b.Value); + } + + public static UserDataHelper operator *(UserDataHelper a, UserDataHelper b) + { + return new UserDataHelper(a.Value * b.Value); + } + + public static UserDataHelper operator /(UserDataHelper a, UserDataHelper b) + { + return new UserDataHelper(a.Value / b.Value); + } + } #pragma warning restore 414, 649 @@ -55,9 +70,80 @@ end "; var script = new Script(input); Assert.Empty(script.Diagnostics.Messages); - var o1 = new UserDataHelper(100); - var o2 = new UserDataHelper(215); - var result = script.EvaluateFunction("add", new[] {o1, o2}); + var o1 = new UserDataHelper(100); + var o2 = new UserDataHelper(215); + var result = script.EvaluateFunction("add", new[] {o1, o2}); + Assert.Empty(script.Diagnostics.Messages); + Assert.Equal(315, result.Value); } + + [Fact] + public void TestAdditionOverloading() + { + const string input = @" +function add(o1, o2) + return o1 + o2 +end +"; + var script = new Script(input); + Assert.Empty(script.Diagnostics.Messages); + var o1 = new UserDataHelper(100); + const double o2 = 1.5; + var result = script.EvaluateFunction("add", new object[] {o1, o2}); + Assert.Empty(script.Diagnostics.Messages); + Assert.Equal(101.5, result.Value); + } + + [Fact] + public void TestSubtraction() + { + const string input = @" +function subtract(o1, o2) + return o1 - o2 +end +"; + var script = new Script(input); + Assert.Empty(script.Diagnostics.Messages); + var o1 = new UserDataHelper(100); + var o2 = new UserDataHelper(1.5); + var result = script.EvaluateFunction("subtract", new object[] {o1, o2}); + Assert.Empty(script.Diagnostics.Messages); + Assert.Equal(98.5, result.Value); + } + + [Fact] + public void TestMultiplication() + { + const string input = @" +function multiply(o1, o2) + return o1 * o2 +end +"; + var script = new Script(input); + Assert.Empty(script.Diagnostics.Messages); + var o1 = new UserDataHelper(100); + var o2 = new UserDataHelper(4); + var result = script.EvaluateFunction("multiply", new object[] {o1, o2}); + Assert.Empty(script.Diagnostics.Messages); + Assert.Equal(400, result.Value); + } + + [Fact] + public void TestDivision() + { + const string input = @" +function divide(o1, o2) + return o1 / o2 +end +"; + var script = new Script(input); + Assert.Empty(script.Diagnostics.Messages); + var o1 = new UserDataHelper(100); + var o2 = new UserDataHelper(10); + var result = script.EvaluateFunction("divide", new object[] {o1, o2}); + Assert.Empty(script.Diagnostics.Messages); + Assert.Equal(10, result.Value); + } + } } \ No newline at end of file