|
| 1 | +import os |
| 2 | + |
| 3 | +from pyspark import keyword_only |
| 4 | +from pyspark.ml import Transformer |
| 5 | +from pyspark.ml.param.shared import ( |
| 6 | + HasInputCol, |
| 7 | + HasOutputCol, |
| 8 | + Param, |
| 9 | + Params, |
| 10 | + TypeConverters, |
| 11 | +) |
| 12 | +from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable |
| 13 | +from pyspark.sql import Row, SparkSession |
| 14 | +from pyspark.sql.types import StringType, StructField, StructType |
| 15 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 16 | + |
| 17 | + |
| 18 | +class _PeekableIterator: |
| 19 | + def __init__(self, iterable): |
| 20 | + self._iterator = iter(iterable) |
| 21 | + self._cache = [] |
| 22 | + |
| 23 | + def __iter__(self): |
| 24 | + return self |
| 25 | + |
| 26 | + def __next__(self): |
| 27 | + if self._cache: |
| 28 | + return self._cache.pop(0) |
| 29 | + else: |
| 30 | + return next(self._iterator) |
| 31 | + |
| 32 | + def peek(self, n=1): |
| 33 | + """Peek at the next n elements without consuming them.""" |
| 34 | + while len(self._cache) < n: |
| 35 | + try: |
| 36 | + self._cache.append(next(self._iterator)) |
| 37 | + except StopIteration: |
| 38 | + break |
| 39 | + if n == 1: |
| 40 | + return self._cache[0] if self._cache else None |
| 41 | + else: |
| 42 | + return self._cache[:n] |
| 43 | + |
| 44 | + |
| 45 | +class _ModelParam: |
| 46 | + def __init__(self, **kwargs): |
| 47 | + self.param = {} |
| 48 | + self.param.update(kwargs) |
| 49 | + |
| 50 | + def get_param(self): |
| 51 | + return self.param |
| 52 | + |
| 53 | + |
| 54 | +class _ModelConfig: |
| 55 | + def __init__(self, **kwargs): |
| 56 | + self.config = {} |
| 57 | + self.config.update(kwargs) |
| 58 | + |
| 59 | + def get_config(self): |
| 60 | + return self.config |
| 61 | + |
| 62 | + def set_config(self, **kwargs): |
| 63 | + self.config.update(kwargs) |
| 64 | + |
| 65 | + |
def broadcast_model(cachePath, modelConfig):
    """Wrap the cached model in a picklable shim and broadcast it to the cluster.

    The heavy model/tokenizer are not serialized; workers reload them from
    cachePath when the broadcast value is unpickled (see _BroadcastableModel).
    """
    spark_context = SparkSession.builder.getOrCreate().sparkContext
    return spark_context.broadcast(_BroadcastableModel(cachePath, modelConfig))
| 70 | + |
| 71 | + |
| 72 | +class _BroadcastableModel: |
| 73 | + def __init__(self, model_path=None, model_config=None): |
| 74 | + self.model_path = model_path |
| 75 | + self.model = None |
| 76 | + self.tokenizer = None |
| 77 | + self.model_config = model_config |
| 78 | + |
| 79 | + def load_model(self): |
| 80 | + if self.model_path and os.path.exists(self.model_path): |
| 81 | + model_config = self.model_config.get_config() |
| 82 | + self.model = AutoModelForCausalLM.from_pretrained( |
| 83 | + self.model_path, local_files_only=True, **model_config |
| 84 | + ) |
| 85 | + self.tokenizer = AutoTokenizer.from_pretrained( |
| 86 | + self.model_path, local_files_only=True |
| 87 | + ) |
| 88 | + else: |
| 89 | + raise ValueError(f"Model path {self.model_path} does not exist.") |
| 90 | + |
| 91 | + def __getstate__(self): |
| 92 | + return {"model_path": self.model_path, "model_config": self.model_config} |
| 93 | + |
| 94 | + def __setstate__(self, state): |
| 95 | + self.model_path = state.get("model_path") |
| 96 | + self.model_config = state.get("model_config") |
| 97 | + self.model = None |
| 98 | + self.tokenizer = None |
| 99 | + if self.model_path: |
| 100 | + self.load_model() |
| 101 | + |
| 102 | + |
class HuggingFaceCausalLM(
    Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable
):
    """Spark ML Transformer that applies a local HuggingFace causal LM to a text column.

    For every input row, the prompt found in ``inputCol`` is passed to the
    model — as a chat conversation when ``task == "chat"`` or as a raw prompt
    when ``task == "completion"`` — and the generated text is appended as a
    new string column named ``outputCol``. The model is loaded per worker by
    ``modelName``, or (when ``cachePath`` is set) loaded once from a shared
    path and broadcast to the cluster.
    """

    modelName = Param(
        Params._dummy(),
        "modelName",
        "huggingface causal lm model name",
        typeConverter=TypeConverters.toString,
    )
    inputCol = Param(
        Params._dummy(),
        "inputCol",
        "input column",
        typeConverter=TypeConverters.toString,
    )
    outputCol = Param(
        Params._dummy(),
        "outputCol",
        "output column",
        typeConverter=TypeConverters.toString,
    )
    task = Param(
        Params._dummy(),
        "task",
        "Specifies the task, can be chat or completion.",
        typeConverter=TypeConverters.toString,
    )
    modelParam = Param(
        Params._dummy(),
        "modelParam",
        "Model Parameters, passed to .generate(). For more details, check https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig",
    )
    modelConfig = Param(
        Params._dummy(),
        "modelConfig",
        "Model configuration, passed to AutoModelForCausalLM.from_pretrained(). For more details, check https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoModelForCausalLM",
    )
    cachePath = Param(
        Params._dummy(),
        "cachePath",
        "cache path for the model. A shared location between the workers, could be a lakehouse path",
        typeConverter=TypeConverters.toString,
    )
    # NOTE(review): deviceMap and torchDtype are declared as Params but are not
    # consumed when loading the model in this file — presumably callers merge
    # them into modelConfig; confirm against the calling code.
    deviceMap = Param(
        Params._dummy(),
        "deviceMap",
        "Specifies a model parameter for the device map. It can also be set with modelParam. Commonly used values include 'auto', 'cuda', or 'cpu'. You may want to check your model documentation for device map",
        typeConverter=TypeConverters.toString,
    )
    torchDtype = Param(
        Params._dummy(),
        "torchDtype",
        "Specifies a model parameter for the torch dtype. It can be set with modelParam. The most commonly used value is 'auto'. You may want to check your model documentation for torch dtype.",
        typeConverter=TypeConverters.toString,
    )

    @keyword_only
    def __init__(
        self,
        modelName=None,
        inputCol=None,
        outputCol=None,
        task="chat",
        cachePath=None,
        deviceMap=None,
        torchDtype=None,
    ):
        super(HuggingFaceCausalLM, self).__init__()
        # Register static defaults only; user-supplied values are applied via
        # setParams below. (Previously the constructor arguments were also
        # baked in as *defaults*, which confuses Params.isSet and
        # save/load round-trips.)
        self._setDefault(
            modelName=None,
            inputCol=None,
            outputCol=None,
            modelParam=_ModelParam(),
            modelConfig=_ModelConfig(),
            task="chat",
            cachePath=None,
            deviceMap=None,
            torchDtype=None,
        )
        # Broadcast handle for the cached model; populated by _transform when
        # cachePath is used. Previously getBCObject() raised AttributeError
        # because this attribute was never initialized anywhere.
        self.bcObject = None
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self):
        """Set params from the keyword arguments captured by @keyword_only."""
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setModelName(self, value):
        """Set the HuggingFace model name (hub id) to load on workers."""
        return self._set(modelName=value)

    def getModelName(self):
        return self.getOrDefault(self.modelName)

    def setInputCol(self, value):
        """Set the name of the column containing the prompt text."""
        return self._set(inputCol=value)

    def getInputCol(self):
        return self.getOrDefault(self.inputCol)

    def setOutputCol(self, value):
        """Set the name of the string column that will hold the generation."""
        return self._set(outputCol=value)

    def getOutputCol(self):
        return self.getOrDefault(self.outputCol)

    def setModelParam(self, **kwargs):
        """Set generation kwargs forwarded to model.generate()."""
        param = _ModelParam(**kwargs)
        return self._set(modelParam=param)

    def getModelParam(self):
        return self.getOrDefault(self.modelParam)

    def setModelConfig(self, **kwargs):
        """Set loading kwargs forwarded to AutoModelForCausalLM.from_pretrained()."""
        config = _ModelConfig(**kwargs)
        return self._set(modelConfig=config)

    def getModelConfig(self):
        return self.getOrDefault(self.modelConfig)

    def setTask(self, value):
        """Set the task; only 'chat' and 'completion' are supported."""
        supported_values = ["completion", "chat"]
        if value not in supported_values:
            raise ValueError(
                f"Task must be one of {supported_values}, but got '{value}'."
            )
        return self._set(task=value)

    def getTask(self):
        return self.getOrDefault(self.task)

    def setCachePath(self, value):
        """Set the shared path (e.g. lakehouse) the model is loaded from."""
        return self._set(cachePath=value)

    def getCachePath(self):
        return self.getOrDefault(self.cachePath)

    def setDeviceMap(self, value):
        return self._set(deviceMap=value)

    def getDeviceMap(self):
        return self.getOrDefault(self.deviceMap)

    def setTorchDtype(self, value):
        return self._set(torchDtype=value)

    def getTorchDtype(self):
        return self.getOrDefault(self.torchDtype)

    def getBCObject(self):
        """Return the broadcast model handle (None until _transform runs with cachePath)."""
        return self.bcObject

    def _predict_single_completion(self, prompt, model, tokenizer):
        """Generate a raw text completion for a single prompt string."""
        param = self.getModelParam().get_param()
        # Move inputs to the model's device — the chat path already did this;
        # without it, generation fails for models placed on GPU.
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        outputs = model.generate(input_ids, **param)
        decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return decoded_output

    def _predict_single_chat(self, prompt, model, tokenizer):
        """Generate a chat reply; prompt is a message list or a plain user string."""
        param = self.getModelParam().get_param()
        if isinstance(prompt, list):
            chat = prompt
        else:
            chat = [{"role": "user", "content": prompt}]
        formatted_chat = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
        tokenized_chat = tokenizer(
            formatted_chat, return_tensors="pt", add_special_tokens=False
        )
        inputs = {
            key: tensor.to(model.device) for key, tensor in tokenized_chat.items()
        }
        merged_inputs = {**inputs, **param}
        outputs = model.generate(**merged_inputs)
        # Decode only the newly generated tokens, not the echoed prompt.
        decoded_output = tokenizer.decode(
            outputs[0][inputs["input_ids"].size(1) :], skip_special_tokens=True
        )
        return decoded_output

    def _process_partition(self, iterator, bc_object):
        """Run inference over one partition, yielding rows with the output column added."""
        peekable_iterator = _PeekableIterator(iterator)
        # Empty partition: skip model loading entirely. Note peek() returns
        # None on exhaustion — it never raises StopIteration, so the previous
        # try/except around it was dead code.
        if peekable_iterator.peek() is None:
            return

        if bc_object:
            lc_object = bc_object.value
            model = lc_object.model
            tokenizer = lc_object.tokenizer
        else:
            model_name = self.getModelName()
            model_config = self.getModelConfig().get_config()
            model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
            tokenizer = AutoTokenizer.from_pretrained(model_name)

        task = self.getTask() or "chat"

        for row in peekable_iterator:
            prompt = row[self.getInputCol()]
            if task == "chat":
                result = self._predict_single_chat(prompt, model, tokenizer)
            elif task == "completion":
                result = self._predict_single_completion(prompt, model, tokenizer)
            else:
                raise ValueError(
                    f"Unsupported task '{task}'. Supported tasks are 'chat' and 'completion'."
                )
            row_dict = row.asDict()
            row_dict[self.getOutputCol()] = result
            yield Row(**row_dict)

    def _transform(self, dataset):
        """Append generated text as a new nullable string column.

        Returns a new DataFrame with the input schema plus outputCol.
        """
        if self.getCachePath():
            bc_object = broadcast_model(self.getCachePath(), self.getModelConfig())
        else:
            bc_object = None
        # Keep a handle so getBCObject() works after transform.
        self.bcObject = bc_object
        input_schema = dataset.schema
        output_schema = StructType(
            input_schema.fields + [StructField(self.getOutputCol(), StringType(), True)]
        )
        result_rdd = dataset.rdd.mapPartitions(
            lambda partition: self._process_partition(partition, bc_object)
        )
        result_df = result_rdd.toDF(output_schema)
        return result_df