1010
1111import base64
1212import datetime
13+ import io
14+ import itertools
15+ import numpy as np
16+ import pyarrow as pa
17+ import pyarrow .json
1318import re
1419from decimal import Decimal
1520from ssl import CERT_NONE , CERT_OPTIONAL , CERT_REQUIRED , create_default_context
4045
4146_logger = logging .getLogger (__name__ )
4247
43- _TIMESTAMP_PATTERN = re .compile (r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)' )
48+ _TIMESTAMP_PATTERN = re .compile (r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,9})?)' )
49+ _INTERVAL_DAY_TIME_PATTERN = re .compile (r'(\d+) (\d+):(\d+):(\d+(?:.\d+)?)' )
4450
4551ssl_cert_parameter_map = {
4652 "none" : CERT_NONE ,
@@ -106,9 +112,36 @@ def _parse_timestamp(value):
106112 value = None
107113 return value
108114
115+ def _parse_date (value ):
116+ if value :
117+ format = '%Y-%m-%d'
118+ value = datetime .datetime .strptime (value , format ).date ()
119+ else :
120+ value = None
121+ return value
109122
110- TYPES_CONVERTER = {"DECIMAL_TYPE" : Decimal ,
111- "TIMESTAMP_TYPE" : _parse_timestamp }
def _parse_interval_day_time(value):
    """Convert a Hive interval_day_time string to a ``datetime.timedelta``.

    The value looks like 'D HH:MM:SS' with optional fractional seconds
    (see _INTERVAL_DAY_TIME_PATTERN). Empty/None input yields None,
    matching how NULL column values arrive from the server.

    Raises:
        ValueError: if the string does not match the expected layout.
    """
    if not value:
        return None
    match = _INTERVAL_DAY_TIME_PATTERN.match(value)
    if match is None:
        # ValueError (an Exception subclass) is the conventional type for
        # unparseable input; existing callers catching Exception still work.
        raise ValueError(
            'Cannot convert "{}" into an interval_day_time'.format(value))
    return datetime.timedelta(
        days=int(match.group(1)),
        hours=int(match.group(2)),
        minutes=int(match.group(3)),
        seconds=float(match.group(4)),
    )
138+
# Map Thrift column type names to per-value Python converters applied after
# nulls are unpacked in _unwrap_column. Types not listed here (ints, floats,
# strings, booleans) need no conversion.
TYPES_CONVERTER = {
    "DECIMAL_TYPE": Decimal,
    "TIMESTAMP_TYPE": _parse_timestamp,
    "DATE_TYPE": _parse_date,
    "INTERVAL_DAY_TIME_TYPE": _parse_interval_day_time,
}
112145
113146
114147class HiveParamEscaper (common .ParamEscaper ):
@@ -488,7 +521,50 @@ def cancel(self):
488521 response = self ._connection .client .CancelOperation (req )
489522 _check_status (response )
490523
491- def _fetch_more (self ):
524+ def fetchone (self , schema = []):
525+ return self .fetchmany (1 , schema )
526+
527+ def fetchall (self , schema = []):
528+ return self .fetchmany (- 1 , schema )
529+
530+ def fetchmany (self , size = None , schema = []):
531+ if size is None :
532+ size = self .arraysize
533+
534+ if self ._state == self ._STATE_NONE :
535+ raise exc .ProgrammingError ("No query yet" )
536+
537+ if size == - 1 :
538+ # Fetch everything
539+ self ._fetch_while (lambda : self ._state != self ._STATE_FINISHED , schema )
540+ else :
541+ self ._fetch_while (lambda :
542+ (self ._state != self ._STATE_FINISHED ) and
543+ (self ._data is None or self ._data .num_rows < size ),
544+ schema
545+ )
546+
547+ if not self ._data :
548+ return None
549+
550+ if size == - 1 :
551+ # Fetch everything
552+ size = self ._data .num_rows
553+ else :
554+ size = min (size , self ._data .num_rows )
555+
556+ self ._rownumber += size
557+ rows = self ._data [:size ]
558+
559+ if size == self ._data .num_rows :
560+ # Fetch everything
561+ self ._data = None
562+ else :
563+ self ._data = self ._data [size :]
564+
565+ return rows
566+
567+ def _fetch_more (self , ext_schema ):
492568 """Send another TFetchResultsReq and update state"""
493569 assert (self ._state == self ._STATE_RUNNING ), "Should be running when in _fetch_more"
494570 assert (self ._operationHandle is not None ), "Should have an op handle in _fetch_more"
@@ -503,15 +579,21 @@ def _fetch_more(self):
503579 _check_status (response )
504580 schema = self .description
505581 assert not response .results .rows , 'expected data in columnar format'
506- columns = [_unwrap_column (col , col_schema [1 ]) for col , col_schema in
507- zip (response .results .columns , schema )]
508- new_data = list ( zip ( * columns ))
509- self . _data += new_data
582+ columns = [_unwrap_column (col , col_schema [1 ], e_schema ) for col , col_schema , e_schema in
583+ itertools . zip_longest (response .results .columns , schema , ext_schema )]
584+ names = [ col [ 0 ] for col in schema ]
585+ new_data = pa . Table . from_batches ([ pa . RecordBatch . from_arrays ( columns , names = names )])
510586 # response.hasMoreRows seems to always be False, so we instead check the number of rows
511587 # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
512588 # if not response.hasMoreRows:
513- if not new_data :
589+ if new_data . num_rows == 0 :
514590 self ._state = self ._STATE_FINISHED
591+ return
592+
593+ if self ._data is None :
594+ self ._data = new_data
595+ else :
596+ self ._data = pa .concat_tables ([self ._data , new_data ])
515597
516598 def poll (self , get_progress_update = True ):
517599 """Poll for and return the raw status data provided by the Hive Thrift REST API.
@@ -585,21 +667,55 @@ def fetch_logs(self):
585667#
586668
587669
def _unwrap_column(col, type_=None, schema=None):
    """Return a pyarrow array of values from a TColumn instance.

    Exactly one attribute of ``col`` (boolVal, i32Val, stringVal, ...) is
    populated; the first non-None wrapper is converted and returned.

    Args:
        col: Thrift TColumn with ``values`` and a ``nulls`` bitmap.
        type_: Hive type name used to pick a value converter.
        schema: optional pyarrow DataType for complex (JSON-encoded) columns.

    Raises:
        DataError: if no attribute of ``col`` is populated.
    """
    for attr, wrapper in iteritems(col.__dict__):
        if wrapper is None:
            continue
        if attr in ('boolVal', 'byteVal', 'i16Val', 'i32Val', 'i64Val', 'doubleVal'):
            values = wrapper.values
            # The Thrift null bitmap is LSB-first within each byte (bit b of
            # byte i flags row i*8+b — see the `byte & (1 << b)` loop below),
            # so unpack with bitorder='little'; the numpy default ('big')
            # would scramble the mask within every byte.
            nulls = np.unpackbits(
                np.frombuffer(wrapper.nulls, dtype='uint8'),
                bitorder='little').view(bool)
            # Trailing False bits are not transmitted; pad to full length.
            mask = np.zeros(values.shape, dtype='?')
            end = min(len(mask), len(nulls))
            mask[:end] = nulls[:end]

            # float values are transferred as double; narrow back to f4.
            if type_ == 'FLOAT_TYPE':
                values = values.astype('>f4')

            # Values arrive big-endian; byteswap into native order for Arrow.
            result = pa.array(
                values.byteswap().view(values.dtype.newbyteorder()), mask=mask)
        else:
            result = wrapper.values
            nulls = wrapper.nulls  # bit set describing what's null
            if len(result) == 0:
                return pa.array([])
            assert isinstance(nulls, bytes)
            for i, char in enumerate(nulls):
                byte = ord(char) if sys.version_info[0] == 2 else char
                for b in range(8):
                    if byte & (1 << b):
                        result[i * 8 + b] = None
            converter = TYPES_CONVERTER.get(type_, None)
            if converter and type_:
                result = [converter(row) if row else row for row in result]

            if type_ in ('ARRAY_TYPE', 'MAP_TYPE', 'STRUCT_TYPE'):
                # Complex values arrive as JSON text; wrap each row as
                # {"c": ...} and let Arrow's JSON reader build the column.
                fd = io.BytesIO()
                for row in result:
                    if row is None:
                        row = 'null'
                    fd.write(f'{{"c":{row}}}\n'.encode('utf8'))
                fd.seek(0)

                if schema is None:
                    # NOTE: JSON map conversion (from the original struct) is not supported
                    result = pa.json.read_json(
                        fd, parse_options=None)[0].combine_chunks()
                else:
                    sch = pa.schema([('c', schema)])
                    opts = pa.json.ParseOptions(explicit_schema=sch)
                    result = pa.json.read_json(
                        fd, parse_options=opts)[0].combine_chunks()
        return result
    raise DataError("Got empty column value {}".format(col))  # pragma: no cover
605721
0 commit comments