Handle Binary operator overloading for UserData

This commit is contained in:
Deukhoofd 2018-11-21 13:47:16 +01:00
parent 0702b9f271
commit c627100e9c
No known key found for this signature in database
GPG Key ID: B4C087AC81641654
5 changed files with 246 additions and 14 deletions

View File

@ -43,5 +43,10 @@ namespace Upsilon.BaseTypes.UserData
diagnostics.LogError($"Cannot find member '{s}' on type '{Value.GetType()}'", span);
}
}
public (LuaType Type, bool Failed) BinaryOperator(LuaType par1, OperatorType op, LuaType par2)
{
return _typeInfo.BinaryOperator(Value, par1, op, par2);
}
}
}

View File

@ -24,14 +24,14 @@ namespace Upsilon.BaseTypes.UserData
Methods.Add(commonName, new UserDataMethod(methodInfo));
}
}
_operatorHandler = new UserDataTypeOperators(type);
OperatorHandler = new UserDataTypeOperators(type);
}
private System.Type Type { get; }
private Dictionary<string, FieldInfo> Variables { get; }
private Dictionary<string, PropertyInfo> Properties { get; }
private Dictionary<string, UserDataMethod> Methods { get; }
private UserDataTypeOperators _operatorHandler { get; }
private UserDataTypeOperators OperatorHandler { get; }
public (LuaType Type, bool Failed) Get(object value, string member)
{
@ -72,9 +72,15 @@ namespace Upsilon.BaseTypes.UserData
return true;
}
public (LuaType Type, bool Failed) BinaryOperator(object value, OperatorType op, object value2)
public (LuaType Type, bool Failed) BinaryOperator(object value, LuaType par1, OperatorType op, LuaType par2)
{
return (null, true);
var method = OperatorHandler.GetBinaryOperator(op, par1.GetCSharpType(), par2.GetCSharpType());
if (method == null)
{
return (new LuaNull(), true);
}
return (method.Invoke(value, new[] {par1.ToCSharpObject(), par2.ToCSharpObject()}).ToLuaType(), false);
}
}
}

View File

@ -1,24 +1,118 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
namespace Upsilon.BaseTypes.UserData
{
public enum OperatorType
{
UnaryPlus,
UnaryNegation,
LogicalNot,
Addition,
Subtraction,
Multiplication,
Division,
}
public class UserDataTypeOperators
{
private class OperatorKeyData
private class OperatorMethod
{
public MethodInfo Info { get; }
public int[] ParameterTypeHashes { get; }
public OperatorMethod(MethodInfo info)
{
Info = info;
ParameterTypeHashes = info.GetParameters().Select(x => x.ParameterType.GetHashCode()).ToArray();
}
}
public UserDataTypeOperators(System.Type t)
private readonly Dictionary<OperatorType, List<OperatorMethod>> _operatorMethods;
public UserDataTypeOperators(IReflect t)
{
var additionMethod = t.GetMethod("op_Addition", BindingFlags.Static | BindingFlags.Public);
Console.WriteLine(additionMethod);
var staticMethod = t.GetMethods(BindingFlags.Static | BindingFlags.Public);
_operatorMethods = new Dictionary<OperatorType, List<OperatorMethod>>();
foreach (var methodInfo in staticMethod)
{
switch (methodInfo.Name)
{
// Unary
case "op_UnaryPlus":
LoadMethod(OperatorType.UnaryPlus, methodInfo);
break;
case "op_UnaryNegation":
LoadMethod(OperatorType.UnaryNegation, methodInfo);
break;
case "op_LogicalNot":
LoadMethod(OperatorType.LogicalNot, methodInfo);
break;
// Binary
case "op_Addition":
LoadMethod(OperatorType.Addition, methodInfo);
break;
case "op_Subtraction":
LoadMethod(OperatorType.Subtraction, methodInfo);
break;
case "op_Multiply":
LoadMethod(OperatorType.Multiplication, methodInfo);
break;
case "op_Division":
LoadMethod(OperatorType.Division, methodInfo);
break;
}
}
}
private void LoadMethod(OperatorType type, MethodInfo method)
{
if (_operatorMethods.TryGetValue(type, out var ls))
{
ls.Add(new OperatorMethod(method));
}
else
{
_operatorMethods.Add(type, new List<OperatorMethod>(){new OperatorMethod(method)});
}
}
public MethodInfo GetBinaryOperator(OperatorType op, System.Type t1, System.Type t2)
{
if (!_operatorMethods.TryGetValue(op, out var m))
{
return null;
}
foreach (var operatorMethod in m)
{
if (operatorMethod.ParameterTypeHashes[0] == t1.GetHashCode() &&
operatorMethod.ParameterTypeHashes[1] == t2.GetHashCode())
return operatorMethod.Info;
if (operatorMethod.ParameterTypeHashes[1] == t2.GetHashCode() &&
operatorMethod.ParameterTypeHashes[0] == t1.GetHashCode())
return operatorMethod.Info;
}
return null;
}
public MethodInfo GetUnaryOperator(OperatorType op, System.Type t1)
{
if (!_operatorMethods.TryGetValue(op, out var m))
{
return null;
}
foreach (var operatorMethod in m)
{
if (operatorMethod.ParameterTypeHashes[0] == t1.GetHashCode())
return operatorMethod.Info;
}
return null;
}
}
}

View File

@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Collections.Immutable;
using Upsilon.BaseTypes;
using Upsilon.BaseTypes.Number;
using Upsilon.BaseTypes.UserData;
using Upsilon.Binder;
using Type = Upsilon.BaseTypes.Type;
@ -185,13 +186,53 @@ namespace Upsilon.Evaluator
{
return ((LuaString) left) + right;
}
else if (left.Type == Type.UserData)
{
var ud = (UserData) left;
var (type, failed) = ud.BinaryOperator(left, OperatorType.Addition, right);
if (failed) goto default;
return type;
}
goto default;
case BoundBinaryOperator.OperatorKind.Subtraction:
return ((Number)left) - ((Number)right);
if (left.Type == Type.Number)
{
return ((Number)left) - ((Number)right);
}
else if (left.Type == Type.UserData)
{
var ud = (UserData) left;
var (type, failed) = ud.BinaryOperator(left, OperatorType.Subtraction, right);
if (failed) goto default;
return type;
}
goto default;
case BoundBinaryOperator.OperatorKind.Multiplication:
return ((Number)left) * ((Number)right);
if (left.Type == Type.Number)
{
return ((Number)left) * ((Number)right);
}
else if (left.Type == Type.UserData)
{
var ud = (UserData) left;
var (type, failed) = ud.BinaryOperator(left, OperatorType.Multiplication, right);
if (failed) goto default;
return type;
}
goto default;
case BoundBinaryOperator.OperatorKind.Division:
return ((Number)left) / ((Number)right);
if (left.Type == Type.Number)
{
return ((Number)left) / ((Number)right);
}
else if (left.Type == Type.UserData)
{
var ud = (UserData) left;
var (type, failed) = ud.BinaryOperator(left, OperatorType.Division, right);
if (failed) goto default;
return type;
}
goto default;
case BoundBinaryOperator.OperatorKind.Equality:
return new LuaBoolean(Equals(left, right));
case BoundBinaryOperator.OperatorKind.Inequality:

View File

@ -41,6 +41,21 @@ namespace UpsilonTests
return new UserDataHelper(a.Value + b);
}
public static UserDataHelper operator -(UserDataHelper a, UserDataHelper b)
{
return new UserDataHelper(a.Value - b.Value);
}
public static UserDataHelper operator *(UserDataHelper a, UserDataHelper b)
{
return new UserDataHelper(a.Value * b.Value);
}
public static UserDataHelper operator /(UserDataHelper a, UserDataHelper b)
{
return new UserDataHelper(a.Value / b.Value);
}
}
#pragma warning restore 414, 649
@ -55,9 +70,80 @@ end
";
var script = new Script(input);
Assert.Empty(script.Diagnostics.Messages);
var o1 = new UserDataHelper(100);
var o2 = new UserDataHelper(215);
var result = script.EvaluateFunction("add", new[] {o1, o2});
var o1 = new UserDataHelper(100);
var o2 = new UserDataHelper(215);
var result = script.EvaluateFunction<UserDataHelper>("add", new[] {o1, o2});
Assert.Empty(script.Diagnostics.Messages);
Assert.Equal(315, result.Value);
}
[Fact]
public void TestAdditionOverloading()
{
const string input = @"
function add(o1, o2)
return o1 + o2
end
";
var script = new Script(input);
Assert.Empty(script.Diagnostics.Messages);
var o1 = new UserDataHelper(100);
const double o2 = 1.5;
var result = script.EvaluateFunction<UserDataHelper>("add", new object[] {o1, o2});
Assert.Empty(script.Diagnostics.Messages);
Assert.Equal(101.5, result.Value);
}
[Fact]
public void TestSubtraction()
{
const string input = @"
function subtract(o1, o2)
return o1 - o2
end
";
var script = new Script(input);
Assert.Empty(script.Diagnostics.Messages);
var o1 = new UserDataHelper(100);
var o2 = new UserDataHelper(1.5);
var result = script.EvaluateFunction<UserDataHelper>("subtract", new object[] {o1, o2});
Assert.Empty(script.Diagnostics.Messages);
Assert.Equal(98.5, result.Value);
}
[Fact]
public void TestMultiplication()
{
const string input = @"
function multiply(o1, o2)
return o1 * o2
end
";
var script = new Script(input);
Assert.Empty(script.Diagnostics.Messages);
var o1 = new UserDataHelper(100);
var o2 = new UserDataHelper(4);
var result = script.EvaluateFunction<UserDataHelper>("multiply", new object[] {o1, o2});
Assert.Empty(script.Diagnostics.Messages);
Assert.Equal(400, result.Value);
}
[Fact]
public void TestDivision()
{
const string input = @"
function divide(o1, o2)
return o1 / o2
end
";
var script = new Script(input);
Assert.Empty(script.Diagnostics.Messages);
var o1 = new UserDataHelper(100);
var o2 = new UserDataHelper(10);
var result = script.EvaluateFunction<UserDataHelper>("divide", new object[] {o1, o2});
Assert.Empty(script.Diagnostics.Messages);
Assert.Equal(10, result.Value);
}
}
}