diff --git a/trino-csharp/Trino.Client.Test/TrinoFunctionTests.cs b/trino-csharp/Trino.Client.Test/TrinoFunctionTests.cs new file mode 100644 index 0000000..3ef4aad --- /dev/null +++ b/trino-csharp/Trino.Client.Test/TrinoFunctionTests.cs @@ -0,0 +1,112 @@ +using Trino.Client.Utils; + +namespace Trino.Client.Test +{ + [TestClass] + public class TrinoFunctionTests + { + /// + /// We subclass the TrinoFunction class to easily access the output from BuildFunctionStatement() + /// + class TrinoFunctionTestHelper : TrinoFunction + { + public TrinoFunctionTestHelper(string catalog, string functionName, IList parameters) : base(catalog, functionName, parameters) + { + } + + public string GetStatement() + { + return base.BuildFunctionStatement(); + } + } + + [TestMethod] + public void TestNoParams() + { + var testTrinoFunction = new TrinoFunctionTestHelper("test", "testFunc", []); + var statement = testTrinoFunction.GetStatement(); + + Assert.AreEqual("test.testFunc()", statement); + } + + [TestMethod] + public void TestOneStringParam() + { + var testTrinoFunction = new TrinoFunctionTestHelper("test", "testFunc", ["a"]); + var statement = testTrinoFunction.GetStatement(); + + Assert.AreEqual("test.testFunc('a')", statement); + } + + [TestMethod] + public void TestTwoStringParams() + { + var testTrinoFunction = new TrinoFunctionTestHelper("test", "testFunc", ["a", "b"]); + var statement = testTrinoFunction.GetStatement(); + + Assert.AreEqual("test.testFunc('a', 'b')", statement); + } + + [TestMethod] + public void TestNullCatalogue() + { + var testTrinoFunction = new TrinoFunctionTestHelper(null!, "testFunc", []); + var statement = testTrinoFunction.GetStatement(); + + Assert.AreEqual("testFunc()", statement); + } + + [TestMethod] + public void TestEmptyCatalogue() + { + var testTrinoFunction = new TrinoFunctionTestHelper(string.Empty, "testFunc", []); + var statement = testTrinoFunction.GetStatement(); + + Assert.AreEqual("testFunc()", statement); + } + + [TestMethod] + public void TestIntegralParams() + { + IList args = [(sbyte)1, (byte)2, (short)3, (ushort)4, 5, 6U, 7L, 8UL, (nint)9, (nuint)10]; + var testTrinoFunction = new TrinoFunctionTestHelper(string.Empty, "testFunc", args); + var statement = testTrinoFunction.GetStatement(); + + Assert.AreEqual("testFunc(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)", statement); + } + + [TestMethod] + public void TestFractionalParams() + { + IList args = [1.5f, 2.5d, 3.5m]; + var testTrinoFunction = new TrinoFunctionTestHelper(string.Empty, "testFunc", args); + var statement = testTrinoFunction.GetStatement(); + + Assert.AreEqual("testFunc(1.5, 2.5, 3.5)", statement); + } + + [TestMethod] + public void TestBooleanParams() + { + IList args = [true, false]; + var testTrinoFunction = new TrinoFunctionTestHelper(string.Empty, "testFunc", args); + var statement = testTrinoFunction.GetStatement(); + + Assert.AreEqual("testFunc(1, 0)", statement); + } + + [TestMethod] + public void TestDateParams() + { + IList args = [ + new DateTime(2025, 4, 3), + new DateTime(2025, 4, 3, 12, 0, 0), + new DateTimeOffset(2025, 4, 3, 12, 0, 0, new TimeSpan(1, 0, 0)) + ]; + var testTrinoFunction = new TrinoFunctionTestHelper(string.Empty, "testFunc", args); + var statement = testTrinoFunction.GetStatement(); + + Assert.AreEqual("testFunc('2025-04-03T00:00:00', '2025-04-03T12:00:00', '2025-04-03T12:00:00')", statement); + } + } +} diff --git a/trino-csharp/Trino.Client/Utils/TrinoFunction.cs b/trino-csharp/Trino.Client/Utils/TrinoFunction.cs index 932354e..14a91b9 100644 --- a/trino-csharp/Trino.Client/Utils/TrinoFunction.cs +++ b/trino-csharp/Trino.Client/Utils/TrinoFunction.cs @@ -1,5 +1,5 @@ +using System; using System.Collections.Generic; -using System.Text; using System.Threading.Tasks; namespace Trino.Client.Utils @@ -25,37 +25,63 @@ public virtual Task ExecuteAsync(ClientSession session) protected virtual string BuildFunctionStatement() { - StringBuilder stringBuilder = new StringBuilder(); + var fullFunctionName = BuildFullFunctionName(); + var parameterString = string.Join(", ", FormatParameters()); + + return $"{fullFunctionName}({parameterString})"; + } + + private string BuildFullFunctionName() + { if (!string.IsNullOrEmpty(catalog)) { - stringBuilder.Append(this.catalog); - stringBuilder.Append("."); + return $"{catalog}.{functionName}"; } - stringBuilder.Append(this.functionName); - stringBuilder.Append("("); - for (int i = 0; i < Parameters.Count; i++) + return functionName; + } + + private IEnumerable FormatParameters() + { + foreach (var parameter in Parameters) { - if (i > 0) + if (IsNumeric(parameter)) { - stringBuilder.Append(", "); + yield return parameter.ToString(); } - - // if parameter is a digit, do not quote it - if (Parameters[i] is int || Parameters[i] is long || Parameters[i] is float || Parameters[i] is double) + else if (parameter is bool) { - stringBuilder.Append(Parameters[i]); + yield return (bool)parameter ? "1" : "0"; + } + else if (parameter is DateTime) + { + yield return QuoteParameter(((DateTime)parameter).ToString("s")); + } + else if (parameter is DateTimeOffset) + { + yield return QuoteParameter(((DateTimeOffset)parameter).ToString("s")); } else { - stringBuilder.Append("'"); - stringBuilder.Append(Parameters[i]); - stringBuilder.Append("'"); + yield return QuoteParameter(parameter.ToString()); } } - stringBuilder.Append(")"); + } - return stringBuilder.ToString(); + private static bool IsNumeric(object parameter) + { + return parameter is byte || parameter is sbyte + || parameter is short || parameter is ushort + || parameter is int || parameter is uint + || parameter is long || parameter is ulong + || parameter is IntPtr || parameter is UIntPtr + || parameter is float || parameter is double + || parameter is decimal; + } + + private static string QuoteParameter(string formattedParameter) + { + return $"'{formattedParameter}'"; } } }