Skip to content

Commit 30a4c28

Browse files
committed
Add support for ±Infinity and NaN in NUMBER
Extend `NUMBER` type to allow representing positive/negative infinity and not-a-number values. This will allow e.g. the following: - safely convert all values of `real` and `double` to `number` - read PostgreSQL's `decimal` values (PostgreSQL `decimal` can hold infinity, and the unconstrained decimal (with dynamic scale) can also hold not-a-number)
1 parent 458d760 commit 30a4c28

File tree

25 files changed

+1506
-147
lines changed

25 files changed

+1506
-147
lines changed

client/trino-client/src/main/java/io/trino/client/JsonDecodingUtils.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
import static io.trino.client.ClientStandardTypes.UUID;
6666
import static io.trino.client.ClientStandardTypes.VARBINARY;
6767
import static io.trino.client.ClientStandardTypes.VARCHAR;
68+
import static java.lang.Double.parseDouble;
6869
import static java.lang.String.format;
6970
import static java.nio.charset.StandardCharsets.UTF_8;
7071
import static java.util.Collections.unmodifiableList;
@@ -276,7 +277,12 @@ public Object decode(JsonParser parser)
276277
// TODO maybe we could send numbers without base64. This probably requires client capabilities.
277278
// Old clients apply base64 decoding to any type they don't recognize.
278279
// TODO If Base64 stays, `parser.getBinaryValue(Base64Variants.MIME)` might be a better way to parse it.
279-
return new BigDecimal(new String(Base64.getDecoder().decode(parser.getValueAsString()), UTF_8));
280+
String stringified = new String(Base64.getDecoder().decode(parser.getValueAsString()), UTF_8);
281+
double doubleValue = parseDouble(stringified);
282+
if (!Double.isFinite(doubleValue)) {
283+
return doubleValue;
284+
}
285+
return new BigDecimal(stringified);
280286
}
281287
}
282288

core/trino-main/src/main/java/io/trino/operator/JoinDomainBuilder.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import static io.trino.spi.function.InvocationConvention.simpleConvention;
4444
import static io.trino.spi.predicate.Range.range;
4545
import static io.trino.spi.type.DoubleType.DOUBLE;
46+
import static io.trino.spi.type.NumberType.NUMBER;
4647
import static io.trino.spi.type.RealType.REAL;
4748
import static io.trino.spi.type.TypeUtils.isFloatingPointNaN;
4849
import static io.trino.spi.type.TypeUtils.readNativeValue;
@@ -113,8 +114,8 @@ public JoinDomainBuilder(
113114
this.maxFilterSizeInBytes = maxFilterSize.toBytes();
114115
this.notifyStateChange = requireNonNull(notifyStateChange, "notifyStateChange is null");
115116

116-
// Skipping DOUBLE and REAL in collectMinMaxValues to avoid dealing with NaN values
117-
this.collectMinMax = minMaxEnabled && type.isOrderable() && type != DOUBLE && type != REAL;
117+
// Skipping REAL, DOUBLE and NUMBER in collectMinMaxValues to avoid dealing with NaN values
118+
this.collectMinMax = minMaxEnabled && type.isOrderable() && type != REAL && type != DOUBLE && type != NUMBER;
118119

119120
MethodHandle readOperator = typeOperators.getReadValueOperator(type, simpleConvention(NULLABLE_RETURN, FLAT));
120121
readOperator = readOperator.asType(readOperator.type().changeReturnType(Object.class));

core/trino-main/src/main/java/io/trino/sql/planner/DomainTranslator.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
import io.trino.spi.predicate.SortedRangeSet;
3535
import io.trino.spi.predicate.TupleDomain;
3636
import io.trino.spi.predicate.ValueSet;
37-
import io.trino.spi.type.DoubleType;
38-
import io.trino.spi.type.RealType;
3937
import io.trino.spi.type.Type;
4038
import io.trino.spi.type.VarcharType;
4139
import io.trino.sql.InterpretedFunctionInvoker;
@@ -86,6 +84,9 @@
8684
import static io.trino.spi.function.OperatorType.SATURATED_FLOOR_CAST;
8785
import static io.trino.spi.type.BooleanType.BOOLEAN;
8886
import static io.trino.spi.type.DateType.DATE;
87+
import static io.trino.spi.type.DoubleType.DOUBLE;
88+
import static io.trino.spi.type.NumberType.NUMBER;
89+
import static io.trino.spi.type.RealType.REAL;
8990
import static io.trino.spi.type.TypeUtils.isFloatingPointNaN;
9091
import static io.trino.sql.ir.Booleans.FALSE;
9192
import static io.trino.sql.ir.Booleans.TRUE;
@@ -226,7 +227,7 @@ private List<Expression> extractDisjuncts(Type type, Ranges ranges, Reference re
226227
That result would be further translated to expression "xxx <> 1.0", which is satisfied by NaN.
227228
To avoid error, in such case the ranges are not optimised.
228229
*/
229-
if (type instanceof RealType || type instanceof DoubleType) {
230+
if (type == REAL || type == DOUBLE || type == NUMBER) {
230231
boolean originalRangeIsAll = orderedRanges.stream().anyMatch(Range::isAll);
231232
boolean coalescedRangeIsAll = originalUnionSingleValues.stream().anyMatch(Range::isAll);
232233
if (!originalRangeIsAll && coalescedRangeIsAll) {
@@ -401,7 +402,7 @@ protected ExtractionResult visitLogical(Logical node, Boolean complement)
401402
remainingExpression = residuals.get(0);
402403
}
403404
else if (matchingSingleSymbolDomains) {
404-
// Types REAL and DOUBLE require special handling because they include NaN value. In this case, we cannot rely on the union of domains.
405+
// Types REAL, DOUBLE and NUMBER require special handling because they include NaN value. In this case, we cannot rely on the union of domains.
405406
// That is because domains covering the value set partially might union up to a domain covering the whole value set.
406407
// While the component domains didn't include NaN, the resulting domain could be further translated to predicate "TRUE" or "a IS NOT NULL",
407408
// which is satisfied by NaN. So during domain union, NaN might be implicitly added.
@@ -422,7 +423,7 @@ else if (matchingSingleSymbolDomains) {
422423
boolean unionedDomainContainsNaN = columnUnionedTupleDomain.isAll() ||
423424
(columnUnionedTupleDomain.getDomains().isPresent() &&
424425
getOnlyElement(columnUnionedTupleDomain.getDomains().get().values()).getValues().isAll());
425-
boolean implicitlyAddedNaN = (type instanceof RealType || type instanceof DoubleType) &&
426+
boolean implicitlyAddedNaN = (type == REAL || type == DOUBLE || type == NUMBER) &&
426427
tupleDomains.stream().noneMatch(TupleDomain::isAll) &&
427428
unionedDomainContainsNaN;
428429
if (!implicitlyAddedNaN) {
@@ -677,7 +678,7 @@ private static Optional<Domain> extractOrderableDomain(Comparison.Operator compa
677678
checkArgument(value != null);
678679

679680
// Handle orderable types which do not have NaN.
680-
if (!(type instanceof DoubleType) && !(type instanceof RealType)) {
681+
if (type != REAL && type != DOUBLE && type != NUMBER) {
681682
return switch (comparisonOperator) {
682683
case EQUAL -> Optional.of(Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.equal(type, value)), complement), false));
683684
case IDENTICAL -> Optional.of(Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.equal(type, value)), complement), complement));
@@ -911,7 +912,7 @@ private Optional<ExtractionResult> processSimpleInPredicate(In node, Boolean com
911912
// in case of IN, NULL on the right results with NULL comparison result (effectively false in predicate context), so can be ignored, as the
912913
// comparison results are OR-ed
913914
}
914-
else if (type instanceof RealType || type instanceof DoubleType) {
915+
else if (type == REAL || type == DOUBLE || type == NUMBER) {
915916
// NaN can be ignored: it always compares to false, as if it was not among IN's values
916917
if (!isFloatingPointNaN(type, constant.value())) {
917918
if (complement) {

core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownNegationsExpressionRewriter.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616
import com.google.common.collect.ImmutableList;
1717
import io.trino.metadata.ResolvedFunction;
18-
import io.trino.spi.type.DoubleType;
19-
import io.trino.spi.type.RealType;
2018
import io.trino.spi.type.Type;
2119
import io.trino.sql.ir.Call;
2220
import io.trino.sql.ir.Comparison;
@@ -30,6 +28,9 @@
3028

3129
import static com.google.common.collect.ImmutableList.toImmutableList;
3230
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
31+
import static io.trino.spi.type.DoubleType.DOUBLE;
32+
import static io.trino.spi.type.NumberType.NUMBER;
33+
import static io.trino.spi.type.RealType.REAL;
3334
import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN;
3435
import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN_OR_EQUAL;
3536
import static io.trino.sql.ir.Comparison.Operator.IDENTICAL;
@@ -89,7 +90,7 @@ public Expression rewriteCall(Call node, Void context, ExpressionTreeRewriter<Vo
8990

9091
private boolean typeHasNaN(Type type)
9192
{
92-
return type instanceof DoubleType || type instanceof RealType;
93+
return type == REAL || type == DOUBLE || type == NUMBER;
9394
}
9495
}
9596
}

core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525
import io.trino.spi.type.CharType;
2626
import io.trino.spi.type.DateType;
2727
import io.trino.spi.type.DecimalType;
28-
import io.trino.spi.type.DoubleType;
2928
import io.trino.spi.type.LongTimestampWithTimeZone;
30-
import io.trino.spi.type.RealType;
3129
import io.trino.spi.type.TimeWithTimeZoneType;
3230
import io.trino.spi.type.TimeZoneKey;
3331
import io.trino.spi.type.TimestampType;
@@ -65,6 +63,7 @@
6563
import static io.trino.spi.type.DateType.DATE;
6664
import static io.trino.spi.type.DoubleType.DOUBLE;
6765
import static io.trino.spi.type.IntegerType.INTEGER;
66+
import static io.trino.spi.type.NumberType.NUMBER;
6867
import static io.trino.spi.type.RealType.REAL;
6968
import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND;
7069
import static io.trino.spi.type.TypeUtils.isFloatingPointNaN;
@@ -481,7 +480,7 @@ private Object coerce(Object value, ResolvedFunction coercion)
481480

482481
private boolean typeHasNaN(Type type)
483482
{
484-
return type instanceof DoubleType || type instanceof RealType;
483+
return type == REAL || type == DOUBLE || type == NUMBER;
485484
}
486485

487486
private int compare(Type type, Object first, Object second)

core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import static com.google.common.collect.ImmutableList.toImmutableList;
6060
import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND;
6161
import static io.trino.spi.type.Timestamps.roundDiv;
62+
import static java.lang.Double.parseDouble;
6263
import static java.util.Objects.requireNonNull;
6364
import static java.util.stream.Collectors.toList;
6465
import static java.util.stream.Collectors.toSet;
@@ -324,7 +325,13 @@ private static MaterializedRow convertToTestTypes(MaterializedRow trinoRow)
324325
case SqlTimestamp sqlTimestamp -> sqlTimestamp.toLocalDateTime();
325326
case SqlTimestampWithTimeZone sqlTimestampWithTimeZone -> sqlTimestampWithTimeZone.toZonedDateTime();
326327
case SqlDecimal sqlDecimal -> sqlDecimal.toBigDecimal();
327-
case SqlNumber sqlNumber -> new BigDecimal(sqlNumber.stringified());
328+
case SqlNumber sqlNumber -> {
329+
double doubleValue = parseDouble(sqlNumber.stringified());
330+
if (!Double.isFinite(doubleValue)) {
331+
yield doubleValue;
332+
}
333+
yield new BigDecimal(sqlNumber.stringified());
334+
}
328335
default -> trinoValue;
329336
};
330337
convertedValues.add(convertedValue);

core/trino-main/src/main/java/io/trino/type/DecimalCasts.java

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
import java.io.IOException;
4141
import java.math.BigDecimal;
42+
import java.util.Optional;
4243

4344
import static io.airlift.slice.Slices.utf8Slice;
4445
import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
@@ -248,7 +249,7 @@ public static Int128 bigintToLongDecimal(long value, long precision, long scale,
248249
{
249250
try {
250251
Int128 result = multiply(tenToScale, value);
251-
if (Decimals.overflows(result, (int) precision)) {
252+
if (overflows(result, (int) precision)) {
252253
throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot cast BIGINT '%s' to DECIMAL(%s, %s)", value, precision, scale));
253254
}
254255
return result;
@@ -306,7 +307,7 @@ public static Int128 integerToLongDecimal(long value, long precision, long scale
306307
{
307308
try {
308309
Int128 result = multiply(tenToScale, value);
309-
if (Decimals.overflows(result, (int) precision)) {
310+
if (overflows(result, (int) precision)) {
310311
throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot cast INTEGER '%s' to DECIMAL(%s, %s)", value, precision, scale));
311312
}
312313
return result;
@@ -365,7 +366,7 @@ public static Int128 smallintToLongDecimal(long value, long precision, long scal
365366
{
366367
try {
367368
Int128 result = multiply(tenToScale, value);
368-
if (Decimals.overflows(result, (int) precision)) {
369+
if (overflows(result, (int) precision)) {
369370
throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot cast SMALLINT '%s' to DECIMAL(%s, %s)", value, precision, scale));
370371
}
371372
return result;
@@ -423,7 +424,7 @@ public static Int128 tinyintToLongDecimal(long value, long precision, long scale
423424
{
424425
try {
425426
Int128 result = multiply(tenToScale, value);
426-
if (Decimals.overflows(result, (int) precision)) {
427+
if (overflows(result, (int) precision)) {
427428
throw new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot cast TINYINT '%s' to DECIMAL(%s, %s)", value, precision, scale));
428429
}
429430
return result;
@@ -498,7 +499,8 @@ public static TrinoNumber longDecimalToNumber(Int128 decimal, long precision, lo
498499
@UsedByGeneratedCode
499500
public static long numberToShortDecimal(TrinoNumber value, long precision, long scale, long tenToScale)
500501
{
501-
BigDecimal bigDecimal = value.toBigDecimal();
502+
BigDecimal bigDecimal = numberToBigDecimal(value)
503+
.orElseThrow(() -> new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot cast NUMBER '%s' to DECIMAL(%s, %s)", value, precision, scale)));
502504
BigDecimal result;
503505
try {
504506
result = bigDecimal.setScale(DecimalConversions.intScale(scale), HALF_UP);
@@ -517,7 +519,8 @@ public static long numberToShortDecimal(TrinoNumber value, long precision, long
517519
@UsedByGeneratedCode
518520
public static Int128 numberToLongDecimal(TrinoNumber value, long precision, long scale, Int128 tenToScale)
519521
{
520-
BigDecimal bigDecimal = value.toBigDecimal();
522+
BigDecimal bigDecimal = numberToBigDecimal(value)
523+
.orElseThrow(() -> new TrinoException(INVALID_CAST_ARGUMENT, format("Cannot cast NUMBER '%s' to DECIMAL(%s, %s)", value, precision, scale)));
521524
BigDecimal result;
522525
try {
523526
result = bigDecimal.setScale(DecimalConversions.intScale(scale), HALF_UP);
@@ -533,6 +536,14 @@ public static Int128 numberToLongDecimal(TrinoNumber value, long precision, long
533536
return Int128.valueOf(result.unscaledValue());
534537
}
535538

539+
private static Optional<BigDecimal> numberToBigDecimal(TrinoNumber value)
540+
{
541+
return switch (value.toBigDecimal()) {
542+
case TrinoNumber.NotANumber _, TrinoNumber.Infinity _ -> Optional.empty();
543+
case TrinoNumber.BigDecimalValue(BigDecimal bigDecimal) -> Optional.of(bigDecimal);
544+
};
545+
}
546+
536547
@UsedByGeneratedCode
537548
public static Slice shortDecimalToVarchar(long decimal, long scale, long varcharLength)
538549
{

core/trino-main/src/main/java/io/trino/type/DoubleOperators.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
import io.trino.spi.function.ScalarOperator;
2525
import io.trino.spi.function.SqlType;
2626
import io.trino.spi.type.StandardTypes;
27+
import io.trino.spi.type.TrinoNumber;
2728

29+
import java.math.BigDecimal;
2830
import java.text.DecimalFormat;
2931
import java.text.DecimalFormatSymbols;
3032

@@ -174,6 +176,19 @@ public static long castToReal(@SqlType(StandardTypes.DOUBLE) double value)
174176
return floatToRawIntBits((float) value);
175177
}
176178

179+
@ScalarOperator(CAST)
180+
@SqlType(StandardTypes.NUMBER)
181+
public static TrinoNumber castToNumber(@SqlType(StandardTypes.DOUBLE) double value)
182+
{
183+
if (Double.isNaN(value)) {
184+
return TrinoNumber.from(new TrinoNumber.NotANumber());
185+
}
186+
if (Double.isInfinite(value)) {
187+
return TrinoNumber.from(new TrinoNumber.Infinity(value < 0.0));
188+
}
189+
return TrinoNumber.from(BigDecimal.valueOf(value));
190+
}
191+
177192
@ScalarOperator(CAST)
178193
@LiteralParameters("x")
179194
@SqlType("varchar(x)")

0 commit comments

Comments
 (0)