Skip to content

Commit 53b76a8

Browse files
feat: Add HuggingFaceCausalLM Transformer for evaluating language models on cluster (#2301)
* poc * poc * rename module * update dependency * add set device type * add Downloader * remove import * update lm * pyarrow version conflict * update transformers version * add dependency * update transformers version * add phi3 test * test missing transformers library * update databricks test * update databricks test * update db library * update doc * format * add broadcast model * temporarily remove horovod for testing * test with previous transformers version * test * test env * test * test * test * test * fix broadcasting * update dependency * test without hadoop client api * update ubuntu version * update E2E phi3 test * exclude phi3 synapse test * bug fix * fix style * format * add gpu test library * update causallm * update model * bug fix * add phi4 to e2e, update transformers version * update env * add dependency * update phi e2e * increase timeout * test run * reduce concurrency on gpu cluster * format * Update core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala
1 parent 25b49cd commit 53b76a8

File tree

9 files changed

+556
-16
lines changed

9 files changed

+556
-16
lines changed
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
import os
2+
3+
from pyspark import keyword_only
4+
from pyspark.ml import Transformer
5+
from pyspark.ml.param.shared import (
6+
HasInputCol,
7+
HasOutputCol,
8+
Param,
9+
Params,
10+
TypeConverters,
11+
)
12+
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
13+
from pyspark.sql import Row, SparkSession
14+
from pyspark.sql.types import StringType, StructField, StructType
15+
from transformers import AutoModelForCausalLM, AutoTokenizer
16+
17+
18+
class _PeekableIterator:
19+
def __init__(self, iterable):
20+
self._iterator = iter(iterable)
21+
self._cache = []
22+
23+
def __iter__(self):
24+
return self
25+
26+
def __next__(self):
27+
if self._cache:
28+
return self._cache.pop(0)
29+
else:
30+
return next(self._iterator)
31+
32+
def peek(self, n=1):
33+
"""Peek at the next n elements without consuming them."""
34+
while len(self._cache) < n:
35+
try:
36+
self._cache.append(next(self._iterator))
37+
except StopIteration:
38+
break
39+
if n == 1:
40+
return self._cache[0] if self._cache else None
41+
else:
42+
return self._cache[:n]
43+
44+
45+
class _ModelParam:
46+
def __init__(self, **kwargs):
47+
self.param = {}
48+
self.param.update(kwargs)
49+
50+
def get_param(self):
51+
return self.param
52+
53+
54+
class _ModelConfig:
55+
def __init__(self, **kwargs):
56+
self.config = {}
57+
self.config.update(kwargs)
58+
59+
def get_config(self):
60+
return self.config
61+
62+
def set_config(self, **kwargs):
63+
self.config.update(kwargs)
64+
65+
66+
def broadcast_model(cachePath, modelConfig):
    """Wrap a model cache path and config in a picklable holder and broadcast it.

    Returns a pyspark Broadcast of a _BroadcastableModel, so each executor
    rebuilds the model/tokenizer from the shared cache path on deserialization.
    """
    spark = SparkSession.builder.getOrCreate()
    holder = _BroadcastableModel(cachePath, modelConfig)
    return spark.sparkContext.broadcast(holder)
70+
71+
72+
class _BroadcastableModel:
73+
def __init__(self, model_path=None, model_config=None):
74+
self.model_path = model_path
75+
self.model = None
76+
self.tokenizer = None
77+
self.model_config = model_config
78+
79+
def load_model(self):
80+
if self.model_path and os.path.exists(self.model_path):
81+
model_config = self.model_config.get_config()
82+
self.model = AutoModelForCausalLM.from_pretrained(
83+
self.model_path, local_files_only=True, **model_config
84+
)
85+
self.tokenizer = AutoTokenizer.from_pretrained(
86+
self.model_path, local_files_only=True
87+
)
88+
else:
89+
raise ValueError(f"Model path {self.model_path} does not exist.")
90+
91+
def __getstate__(self):
92+
return {"model_path": self.model_path, "model_config": self.model_config}
93+
94+
def __setstate__(self, state):
95+
self.model_path = state.get("model_path")
96+
self.model_config = state.get("model_config")
97+
self.model = None
98+
self.tokenizer = None
99+
if self.model_path:
100+
self.load_model()
101+
102+
103+
class HuggingFaceCausalLM(
    Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable
):
    """Spark Transformer that evaluates a Hugging Face causal language model.

    For every row of the input DataFrame, the value of ``inputCol`` is fed to
    the model — as a chat turn or as a raw completion prompt, depending on
    ``task`` — and the generated text is appended as ``outputCol``. When
    ``cachePath`` is set, the model is wrapped in a broadcastable holder and
    shipped to the executors once; otherwise each partition loads
    ``modelName`` via ``from_pretrained`` on its own.
    """

    modelName = Param(
        Params._dummy(),
        "modelName",
        "huggingface causal lm model name",
        typeConverter=TypeConverters.toString,
    )
    inputCol = Param(
        Params._dummy(),
        "inputCol",
        "input column",
        typeConverter=TypeConverters.toString,
    )
    outputCol = Param(
        Params._dummy(),
        "outputCol",
        "output column",
        typeConverter=TypeConverters.toString,
    )
    task = Param(
        Params._dummy(),
        "task",
        "Specifies the task, can be chat or completion.",
        typeConverter=TypeConverters.toString,
    )
    modelParam = Param(
        Params._dummy(),
        "modelParam",
        "Model Parameters, passed to .generate(). For more details, check https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig",
    )
    modelConfig = Param(
        Params._dummy(),
        "modelConfig",
        "Model configuration, passed to AutoModelForCausalLM.from_pretrained(). For more details, check https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoModelForCausalLM",
    )
    cachePath = Param(
        Params._dummy(),
        "cachePath",
        "cache path for the model. A shared location between the workers, could be a lakehouse path",
        typeConverter=TypeConverters.toString,
    )
    deviceMap = Param(
        Params._dummy(),
        "deviceMap",
        "Specifies a model parameter for the device map. It can also be set with modelParam. Commonly used values include 'auto', 'cuda', or 'cpu'. You may want to check your model documentation for device map",
        typeConverter=TypeConverters.toString,
    )
    torchDtype = Param(
        Params._dummy(),
        "torchDtype",
        "Specifies a model parameter for the torch dtype. It can be set with modelParam. The most commonly used value is 'auto'. You may want to check your model documentation for torch dtype.",
        typeConverter=TypeConverters.toString,
    )

    @keyword_only
    def __init__(
        self,
        modelName=None,
        inputCol=None,
        outputCol=None,
        task="chat",
        cachePath=None,
        deviceMap=None,
        torchDtype=None,
    ):
        super(HuggingFaceCausalLM, self).__init__()
        # FIX: getBCObject() previously read self.bcObject, which was never
        # assigned anywhere and therefore always raised AttributeError.
        # Initialize it here; _transform() populates it when cachePath is set.
        self.bcObject = None
        self._setDefault(
            modelName=modelName,
            inputCol=inputCol,
            outputCol=outputCol,
            modelParam=_ModelParam(),
            modelConfig=_ModelConfig(),
            task=task,
            cachePath=None,
            deviceMap=None,
            torchDtype=None,
        )
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self):
        # @keyword_only stashes the caller's kwargs in self._input_kwargs,
        # which is why the explicit signature is intentionally empty.
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setModelName(self, value):
        """Set the Hugging Face model identifier (e.g. an HF hub repo name)."""
        return self._set(modelName=value)

    def getModelName(self):
        return self.getOrDefault(self.modelName)

    def setInputCol(self, value):
        """Set the column containing the prompt (string or chat message list)."""
        return self._set(inputCol=value)

    def getInputCol(self):
        return self.getOrDefault(self.inputCol)

    def setOutputCol(self, value):
        """Set the column to receive the generated text."""
        return self._set(outputCol=value)

    def getOutputCol(self):
        return self.getOrDefault(self.outputCol)

    def setModelParam(self, **kwargs):
        """Set generation kwargs forwarded to model.generate()."""
        param = _ModelParam(**kwargs)
        return self._set(modelParam=param)

    def getModelParam(self):
        return self.getOrDefault(self.modelParam)

    def setModelConfig(self, **kwargs):
        """Set loading kwargs forwarded to AutoModelForCausalLM.from_pretrained()."""
        config = _ModelConfig(**kwargs)
        return self._set(modelConfig=config)

    def getModelConfig(self):
        return self.getOrDefault(self.modelConfig)

    def setTask(self, value):
        """Set the task; only 'chat' and 'completion' are accepted."""
        supported_values = ["completion", "chat"]
        if value not in supported_values:
            raise ValueError(
                f"Task must be one of {supported_values}, but got '{value}'."
            )
        return self._set(task=value)

    def getTask(self):
        return self.getOrDefault(self.task)

    def setCachePath(self, value):
        """Set a worker-shared path holding pre-downloaded model files."""
        return self._set(cachePath=value)

    def getCachePath(self):
        return self.getOrDefault(self.cachePath)

    def setDeviceMap(self, value):
        """Set the device map hint (e.g. 'auto', 'cuda', 'cpu')."""
        return self._set(deviceMap=value)

    def getDeviceMap(self):
        return self.getOrDefault(self.deviceMap)

    def setTorchDtype(self, value):
        """Set the torch dtype hint (most commonly 'auto')."""
        return self._set(torchDtype=value)

    def getTorchDtype(self):
        return self.getOrDefault(self.torchDtype)

    def getBCObject(self):
        """Return the broadcast model handle from the last _transform call, or None."""
        return self.bcObject

    def _predict_single_completion(self, prompt, model, tokenizer):
        """Generate a raw text completion for a single prompt string."""
        param = self.getModelParam().get_param()
        # FIX: move the input ids to the model's device. Previously they were
        # left on CPU here, unlike the chat path, which fails when the model
        # is placed on GPU (e.g. deviceMap='cuda').
        inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        outputs = model.generate(inputs, **param)
        decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return decoded_output

    def _predict_single_chat(self, prompt, model, tokenizer):
        """Generate a chat reply for one prompt.

        `prompt` may already be a list of chat messages; a bare string is
        wrapped as a single user turn. Only the newly generated tokens
        (after the input length) are decoded and returned.
        """
        param = self.getModelParam().get_param()
        if isinstance(prompt, list):
            chat = prompt
        else:
            chat = [{"role": "user", "content": prompt}]
        formatted_chat = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
        tokenized_chat = tokenizer(
            formatted_chat, return_tensors="pt", add_special_tokens=False
        )
        inputs = {
            key: tensor.to(model.device) for key, tensor in tokenized_chat.items()
        }
        merged_inputs = {**inputs, **param}
        outputs = model.generate(**merged_inputs)
        decoded_output = tokenizer.decode(
            outputs[0][inputs["input_ids"].size(1) :], skip_special_tokens=True
        )
        return decoded_output

    def _process_partition(self, iterator, bc_object):
        """Process each partition of the data, yielding rows with outputCol appended."""
        peekable_iterator = _PeekableIterator(iterator)
        # FIX: _PeekableIterator.peek() returns None on exhaustion and never
        # raises StopIteration, so the old try/except StopIteration guard was
        # dead code. Check for None so an empty partition skips model loading.
        if peekable_iterator.peek() is None:
            return

        if bc_object:
            # Broadcast path: the model was loaded once and shipped over.
            lc_object = bc_object.value
            model = lc_object.model
            tokenizer = lc_object.tokenizer
        else:
            # Fallback: this partition loads the model by name itself.
            model_name = self.getModelName()
            model_config = self.getModelConfig().get_config()
            model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
            tokenizer = AutoTokenizer.from_pretrained(model_name)

        task = self.getTask() if self.getTask() else "chat"

        for row in peekable_iterator:
            prompt = row[self.getInputCol()]
            if task == "chat":
                result = self._predict_single_chat(prompt, model, tokenizer)
            elif task == "completion":
                result = self._predict_single_completion(prompt, model, tokenizer)
            else:
                raise ValueError(
                    f"Unsupported task '{task}'. Supported tasks are 'chat' and 'completion'."
                )
            row_dict = row.asDict()
            row_dict[self.getOutputCol()] = result
            yield Row(**row_dict)

    def _transform(self, dataset):
        """Return the dataset with a new string column (outputCol) of generations."""
        if self.getCachePath():
            bc_object = broadcast_model(self.getCachePath(), self.getModelConfig())
        else:
            bc_object = None
        # Expose the broadcast handle through getBCObject().
        self.bcObject = bc_object
        input_schema = dataset.schema
        output_schema = StructType(
            input_schema.fields + [StructField(self.getOutputCol(), StringType(), True)]
        )
        result_rdd = dataset.rdd.mapPartitions(
            lambda partition: self._process_partition(partition, bc_object)
        )
        result_df = result_rdd.toDF(output_schema)
        return result_df

core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class DatabricksGPUTests extends DatabricksTestHelper {
1414

1515
val clusterId: String = createClusterInPool(GPUClusterName, AdbGpuRuntime, 2, GpuPoolId)
1616

17-
databricksTestHelper(clusterId, GPULibraries, GPUNotebooks)
17+
databricksTestHelper(clusterId, GPULibraries, GPUNotebooks, 1)
1818

1919
protected override def afterAll(): Unit = {
2020
afterAllHelper(clusterId, GPUClusterName)

core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,11 @@ object DatabricksUtilities {
8484
Map("maven" -> Map("coordinates" -> PackageMavenCoordinate, "repo" -> PackageRepository)),
8585
Map("pypi" -> Map("package" -> "pytorch-lightning==1.5.0")),
8686
Map("pypi" -> Map("package" -> "torchvision==0.14.1")),
87-
Map("pypi" -> Map("package" -> "transformers==4.32.1")),
87+
Map("pypi" -> Map("package" -> "transformers==4.49.0")),
88+
Map("pypi" -> Map("package" -> "jinja2==3.1.0")),
8889
Map("pypi" -> Map("package" -> "petastorm==0.12.0")),
89-
Map("pypi" -> Map("package" -> "protobuf==3.20.3"))
90+
Map("pypi" -> Map("package" -> "protobuf==3.20.3")),
91+
Map("pypi" -> Map("package" -> "accelerate==0.26.0"))
9092
).toJson.compactPrint
9193

9294
val RapidsInitScripts: String = List(
@@ -105,12 +107,16 @@ object DatabricksUtilities {
105107
val CPUNotebooks: Seq[File] = ParallelizableNotebooks
106108
.filterNot(_.getAbsolutePath.contains("Fine-tune"))
107109
.filterNot(_.getAbsolutePath.contains("GPU"))
110+
.filterNot(_.getAbsolutePath.contains("Phi Model"))
111+
.filterNot(_.getAbsolutePath.contains("Language Model"))
108112
.filterNot(_.getAbsolutePath.contains("Multivariate Anomaly Detection")) // Deprecated
109113
.filterNot(_.getAbsolutePath.contains("Audiobooks")) // TODO Remove this by fixing auth
110114
.filterNot(_.getAbsolutePath.contains("Art")) // TODO Remove this by fixing performance
111115
.filterNot(_.getAbsolutePath.contains("Explanation Dashboard")) // TODO Remove this exclusion
112116

113-
val GPUNotebooks: Seq[File] = ParallelizableNotebooks.filter(_.getAbsolutePath.contains("Fine-tune"))
117+
val GPUNotebooks: Seq[File] = ParallelizableNotebooks.filter { file =>
118+
file.getAbsolutePath.contains("Fine-tune") || file.getAbsolutePath.contains("Phi Model")
119+
}
114120

115121
val RapidsNotebooks: Seq[File] = ParallelizableNotebooks.filter(_.getAbsolutePath.contains("GPU"))
116122

@@ -427,7 +433,8 @@ abstract class DatabricksTestHelper extends TestBase {
427433

428434
def databricksTestHelper(clusterId: String,
429435
libraries: String,
430-
notebooks: Seq[File]): Unit = {
436+
notebooks: Seq[File],
437+
maxConcurrency: Int = 8): Unit = {
431438

432439
println("Checking if cluster is active")
433440
tryWithRetries(Seq.fill(60 * 20)(1000).toArray) { () =>
@@ -443,7 +450,6 @@ abstract class DatabricksTestHelper extends TestBase {
443450

444451
assert(notebooks.nonEmpty)
445452

446-
val maxConcurrency = 8
447453
val executorService = Executors.newFixedThreadPool(maxConcurrency)
448454
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(executorService)
449455

core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/SynapseTests.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class SynapseTests extends TestBase {
4848
.filter(_.getAbsolutePath.endsWith(".py"))
4949
.filterNot(_.getAbsolutePath.contains("Finetune")) // Excluded by design task 1829306
5050
.filterNot(_.getAbsolutePath.contains("GPU"))
51+
.filterNot(_.getAbsolutePath.contains("PhiModel"))
5152
.filterNot(_.getAbsolutePath.contains("VWnativeFormat"))
5253
.filterNot(_.getAbsolutePath.contains("VowpalWabbitMulticlassclassification")) // Wait for Synapse fix
5354
.filterNot(_.getAbsolutePath.contains("Langchain")) // Wait for Synapse fix

0 commit comments

Comments
 (0)