|
14 | 14 | "3. Minimize a biophysical energy function\n", |
15 | 15 | "4. Use experimental screening data to guide designs with a regression model\n", |
16 | 16 | "\n", |
17 | | - "As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Li, et al 2024](https://arxiv.org/abs/2408.08252).\n", |
| 17 | + "As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Li, et al 2024](https://arxiv.org/abs/2408.08252), and constrained optimization using the Modified Differential Method of Multipliers (MDMM) from [Platt & Barr 1987](https://proceedings.neurips.cc/paper_files/paper/1987/file/a1126573153ad7e9f44ba80e99316482-Paper.pdf).\n", |
18 | 18 | "\n", |
19 | 19 | "In this notebook we will walk through a few examples to illustrate how to use guided generation. \n", |
20 | 20 | "\n", |
21 | 21 | "1. Guide towards high pTM for improved generation quality\n", |
22 | 22 | "2. Generate a protein with no cysteine (C) residues\n", |
23 | | - "3. Maximize protein globularity by minimizing the radius of gyration\n", |
| 23 | + "3. Maximize protein globularity by minimizing the radius of gyration, while keeping pTM high\n", |
24 | 24 | "\n" |
25 | 25 | ] |
26 | 26 | }, |
|
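To make the idea of derivative-free guidance in the hunk above more concrete, here is a minimal toy sketch of soft, value-weighted candidate selection. It is not the `ESM3GuidedDecoding` implementation: `propose`, `guided_generate`, `no_cysteine_score`, and the `temperature` parameter are hypothetical names for illustration, and a uniform random sampler stands in for the generative model.

```python
import math
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
MASK = "_"


def no_cysteine_score(sequence: str) -> float:
    """Toy reward (assumption): fewer cysteines is better."""
    return -float(sequence.count("C"))


def propose(sequence: str, num_to_unmask: int, rng: random.Random) -> str:
    """Stand-in for the model: fill some masked positions uniformly at random."""
    masked = [i for i, ch in enumerate(sequence) if ch == MASK]
    chosen = rng.sample(masked, min(num_to_unmask, len(masked)))
    residues = list(sequence)
    for i in chosen:
        residues[i] = rng.choice(AMINO_ACIDS)
    return "".join(residues)


def guided_generate(length, num_steps, num_samples, temperature, score_fn, seed=0):
    """At each step, draw several candidates and keep one, favoring high scores."""
    rng = random.Random(seed)
    sequence = MASK * length
    per_step = math.ceil(length / num_steps)
    for _ in range(num_steps):
        candidates = [propose(sequence, per_step, rng) for _ in range(num_samples)]
        scores = [score_fn(c) for c in candidates]
        # Soft selection: probability proportional to exp(score / temperature).
        # Subtracting the best score keeps the exponentials numerically stable.
        best = max(scores)
        weights = [math.exp((s - best) / temperature) for s in scores]
        sequence = rng.choices(candidates, weights=weights, k=1)[0]
    return sequence


designed = guided_generate(
    length=64,
    num_steps=8,
    num_samples=16,
    temperature=0.1,
    score_fn=no_cysteine_score,
)
print(designed, designed.count("C"))
```

No gradient of the score is ever needed: the scoring function is only called on complete candidates, which is why any function that maps a protein to a single number can be plugged in. A lower `temperature` makes the selection greedier.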
49 | 49 | "source": [ |
50 | 50 | "import biotite.structure as bs\n", |
51 | 51 | "import py3Dmol\n", |
| 52 | + "\n", |
52 | 53 | "from esm.sdk.api import ESMProtein, GenerationConfig\n", |
53 | 54 | "from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction" |
54 | 55 | ] |
|
269 | 270 | "metadata": {}, |
270 | 271 | "outputs": [], |
271 | 272 | "source": [ |
| 273 | + "# Start from a fully masked protein\n", |
| 274 | + "PROTEIN_LENGTH = 256\n", |
| 275 | + "starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n", |
| 276 | + "\n", |
| 277 | + "# Call guided_generate\n", |
272 | 278 | "no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(\n", |
273 | 279 | " protein=starting_protein,\n", |
274 | 280 | " num_decoding_steps=len(starting_protein) // 8,\n", |
|
302 | 308 | "source": [ |
303 | 309 | "## Maximize Globularity\n", |
304 | 310 | "\n", |
305 | | - "We use the radius of gyration as a proxy to maximize globularity, we also encourage generations to have high pTM" |
| 311 | + "We minimize the radius of gyration as a proxy for maximizing globularity, and we use a constraint to keep the pTM of the generations high." |
| 312 | + ] |
| 313 | + }, |
| 314 | + { |
| 315 | + "cell_type": "code", |
| 316 | + "execution_count": null, |
| 317 | + "metadata": {}, |
| 318 | + "outputs": [], |
| 319 | + "source": [ |
| 320 | + "from esm.sdk.experimental import (\n", |
| 321 | + " ConstraintType,\n", |
| 322 | + " ESM3GuidedDecodingWithConstraints,\n", |
| 323 | + " GenerationConstraint,\n", |
| 324 | + ")" |
306 | 325 | ] |
307 | 326 | }, |
308 | 327 | { |
|
313 | 332 | "source": [ |
314 | 333 | "class RadiousOfGyrationScoringFunction(GuidedDecodingScoringFunction):\n", |
315 | 334 | " def __call__(self, protein: ESMProtein) -> float:\n", |
| 335 | + " # Use the negative radius of gyration as the score to maximize\n", |
316 | 336 | " score = -1 * self.radius_of_gyration(protein)\n", |
317 | 337 | "\n", |
318 | | - " assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n", |
319 | | - " if protein.ptm < 0.5:\n", |
320 | | - " # Penalize proteins with low pTM scores\n", |
321 | | - " score = score * 2\n", |
| 338 | + " # Rescale the score so its magnitude is similar to that of pTM\n", |
| 339 | + " score = score / 100.0\n", |
322 | 340 | "\n", |
323 | 341 | " return score\n", |
324 | 342 | "\n", |
|
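For reference, the quantity this scoring function negates can be sketched in a few lines. This is an illustration only, not the notebook's `radius_of_gyration` method: it assumes equal atomic masses and plain `(x, y, z)` tuples, and the helper names `radius_of_gyration` and `globularity_score` are hypothetical.

```python
import math


def radius_of_gyration(coords):
    """Root-mean-square distance of points from their centroid (equal masses assumed)."""
    n = len(coords)
    if n == 0:
        raise ValueError("coords must contain at least one point")
    centroid = tuple(sum(c[k] for c in coords) / n for k in range(3))
    mean_sq_dist = sum(
        sum((c[k] - centroid[k]) ** 2 for k in range(3)) for c in coords
    ) / n
    return math.sqrt(mean_sq_dist)


def globularity_score(coords, scale=100.0):
    """Negative radius of gyration, divided by `scale` to bring it near the 0-1 range of pTM."""
    return -radius_of_gyration(coords) / scale


# An extended chain (3.8 Angstrom spacing) versus the same number of points packed in a cube
extended = [(3.8 * i, 0.0, 0.0) for i in range(27)]
compact = [(3.8 * i, 3.8 * j, 3.8 * k) for i in range(3) for j in range(3) for k in range(3)]
print(globularity_score(extended), globularity_score(compact))
```

The compact arrangement has the smaller radius of gyration and therefore the higher (less negative) score, which is the direction guided generation pushes in.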
335 | 353 | "metadata": {}, |
336 | 354 | "outputs": [], |
337 | 355 | "source": [ |
338 | | - "radius_guided_decoding = ESM3GuidedDecoding(\n", |
339 | | - " client=model, scoring_function=RadiousOfGyrationScoringFunction()\n", |
| 356 | + "# Constrain generation to have pTM >= 0.75\n", |
| 357 | + "ptm_constraint = GenerationConstraint(\n", |
| 358 | + " scoring_function=PTMScoringFunction(),\n", |
| 359 | + " constraint_type=ConstraintType.GREATER_EQUAL,\n", |
| 360 | + " value=0.75,\n", |
| 361 | + ")\n", |
| 362 | + "\n", |
| 363 | + "radius_guided_decoding = ESM3GuidedDecodingWithConstraints(\n", |
| 364 | + " client=model,\n", |
| 365 | + " scoring_function=RadiousOfGyrationScoringFunction(),\n", |
| 366 | + " constraints=[ptm_constraint], # Add list of constraints\n", |
| 367 | + " damping=1.0, # Damping factor for the MDMM algorithm\n", |
| 368 | + " learning_rate=10.0, # Learning rate for the MDMM algorithm\n", |
340 | 369 | ")" |
341 | 370 | ] |
342 | 371 | }, |
|
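To give some intuition for the `damping` and `learning_rate` arguments above, here is a toy sketch of the Modified Differential Method of Multipliers from Platt & Barr on a small equality-constrained problem: minimize `x^2 + y^2` subject to `x + y = 1`. How `ESM3GuidedDecodingWithConstraints` applies these updates internally (for example to inequality constraints or to a derivative-free objective) is not shown in this diff, so the function `mdmm_minimize` and its parameters are assumptions chosen only to mirror the argument names.

```python
def mdmm_minimize(learning_rate=0.05, damping=1.0, num_steps=2000):
    """Gradient descent on (x, y) and gradient ascent on the multiplier `lam`.

    Objective:  f(x, y) = x^2 + y^2
    Constraint: g(x, y) = x + y - 1 = 0
    """
    x, y, lam = 0.0, 0.0, 0.0
    for _ in range(num_steps):
        g = x + y - 1.0                 # current constraint violation
        grad_fx, grad_fy = 2.0 * x, 2.0 * y
        grad_gx, grad_gy = 1.0, 1.0
        # The multiplier term enforces the constraint; the damping term
        # (damping * g * grad g) acts like a penalty that suppresses oscillation.
        force = lam + damping * g
        new_x = x - learning_rate * (grad_fx + force * grad_gx)
        new_y = y - learning_rate * (grad_fy + force * grad_gy)
        # Ascent on the multiplier: it grows while the constraint is violated.
        lam = lam + learning_rate * g
        x, y = new_x, new_y
    return x, y, lam


x, y, lam = mdmm_minimize()
print(x, y, lam)
```

The iterates settle at the constrained optimum `x = y = 0.5`, where the constraint holds exactly, rather than at the unconstrained minimum at the origin. A larger `learning_rate` moves the multiplier faster but can destabilize the updates, while a larger `damping` pulls harder toward the constraint surface.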
346 | 375 | "metadata": {}, |
347 | 376 | "outputs": [], |
348 | 377 | "source": [ |
| 378 | + "# Start from a fully masked protein\n", |
| 379 | + "PROTEIN_LENGTH = 256\n", |
| 380 | + "starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n", |
| 381 | + "\n", |
| 382 | + "# Call guided_generate\n", |
349 | 383 | "radius_guided_protein = radius_guided_decoding.guided_generate(\n", |
350 | 384 | " protein=starting_protein,\n", |
351 | 385 | " num_decoding_steps=len(starting_protein) // 8,\n", |
|
359 | 393 | "metadata": {}, |
360 | 394 | "outputs": [], |
361 | 395 | "source": [ |
| 396 | + "# Visualize the trajectory of the constrained generation\n", |
| 397 | + "radius_guided_decoding.visualize_latest_trajectory()" |
| 398 | + ] |
| 399 | + }, |
| 400 | + { |
| 401 | + "cell_type": "code", |
| 402 | + "execution_count": null, |
| 403 | + "metadata": {}, |
| 404 | + "outputs": [], |
| 405 | + "source": [ |
| 406 | + "# Visualize the generated protein\n", |
362 | 407 | "view = py3Dmol.view(width=800, height=400)\n", |
363 | 408 | "view.addModel(radius_guided_protein.to_pdb_string(), \"pdb\")\n", |
364 | 409 | "view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n", |
365 | 410 | "view.zoomTo()" |
366 | 411 | ] |
| 412 | + }, |
| 413 | + { |
| 414 | + "cell_type": "code", |
| 415 | + "execution_count": null, |
| 416 | + "metadata": {}, |
| 417 | + "outputs": [], |
| 418 | + "source": [ |
| 419 | + "# Check pTM\n", |
| 420 | + "radius_guided_protein.ptm" |
| 421 | + ] |
367 | 422 | } |
368 | 423 | ], |
369 | 424 | "metadata": { |
|