|
44 | 44 | CONST_DEFAULT_WINDOWS_VMS_VM_SIZE, |
45 | 45 | CONST_MANAGED_CLUSTER_SKU_NAME_AUTOMATIC, |
46 | 46 | CONST_SSH_ACCESS_LOCALUSER, |
| 47 | + CONST_GPU_DRIVER_INSTALL, |
47 | 48 | CONST_GPU_DRIVER_NONE, |
| 49 | + CONST_GPU_MANAGEMENT_MODE_MANAGED, |
| 50 | + CONST_GPU_MANAGEMENT_MODE_UNMANAGED, |
48 | 51 | CONST_NODEPOOL_MODE_MANAGEDSYSTEM, |
49 | 52 | CONST_NODEPOOL_MODE_MACHINES, |
50 | 53 | ) |
@@ -587,6 +590,27 @@ def get_enable_artifact_streaming(self) -> bool: |
587 | 590 | enable_artifact_streaming = self.agentpool.artifact_streaming_profile.enabled |
588 | 591 | return enable_artifact_streaming |
589 | 592 |
|
| 593 | + def get_enable_managed_gpu(self) -> Union[bool, None]: |
| 594 | + """Obtain the value of enable_managed_gpu. |
| 595 | + :return: bool |
| 596 | + """ |
| 597 | + |
| 598 | + # read the original value passed by the command |
| 599 | + enable_managed_gpu = self.raw_param.get("enable_managed_gpu") |
| 600 | + |
| 601 | + # In create mode, try to read the property value corresponding to the parameter from the `agentpool` object |
| 602 | + if self.decorator_mode == DecoratorMode.CREATE: |
| 603 | + if ( |
| 604 | + self.agentpool and |
| 605 | + self.agentpool.gpu_profile is not None and |
| 606 | + self.agentpool.gpu_profile.nvidia is not None and |
| 607 | + self.agentpool.gpu_profile.nvidia.management_mode is not None |
| 608 | + ): |
| 609 | + enable_managed_gpu = ( |
| 610 | + self.agentpool.gpu_profile.nvidia.management_mode == CONST_GPU_MANAGEMENT_MODE_MANAGED |
| 611 | + ) |
| 612 | + return enable_managed_gpu |
| 613 | + |
590 | 614 | def get_pod_ip_allocation_mode(self: bool = False) -> Union[str, None]: |
591 | 615 | """Get the value of pod_ip_allocation_mode. |
592 | 616 | :return: str or None |
@@ -1276,6 +1300,21 @@ def set_up_artifact_streaming(self, agentpool: AgentPool) -> AgentPool: |
1276 | 1300 | agentpool.artifact_streaming_profile.enabled = True |
1277 | 1301 | return agentpool |
1278 | 1302 |
|
| 1303 | + def set_up_managed_gpu(self, agentpool: AgentPool) -> AgentPool: |
| 1304 | + """Set up managed GPU property for the AgentPool object.""" |
| 1305 | + self._ensure_agentpool(agentpool) |
| 1306 | + |
| 1307 | + enable_managed_gpu = self.context.get_enable_managed_gpu() |
| 1308 | + |
| 1309 | + if enable_managed_gpu: |
| 1310 | + if agentpool.gpu_profile is None: |
| 1311 | + agentpool.gpu_profile = self.models.GPUProfile() # pylint: disable=no-member |
| 1312 | + if agentpool.gpu_profile.nvidia is None: |
| 1313 | + agentpool.gpu_profile.nvidia = self.models.NvidiaGPUProfile() # pylint: disable=no-member |
| 1314 | + agentpool.gpu_profile.nvidia.management_mode = CONST_GPU_MANAGEMENT_MODE_MANAGED |
| 1315 | + agentpool.gpu_profile.driver = CONST_GPU_DRIVER_INSTALL |
| 1316 | + return agentpool |
| 1317 | + |
1279 | 1318 | def set_up_ssh_access(self, agentpool: AgentPool) -> AgentPool: |
1280 | 1319 | self._ensure_agentpool(agentpool) |
1281 | 1320 |
|
@@ -1510,6 +1549,8 @@ def construct_agentpool_profile_preview(self) -> AgentPool: |
1510 | 1549 | agentpool = self.set_up_init_taints(agentpool) |
1511 | 1550 | # set up artifact streaming |
1512 | 1551 | agentpool = self.set_up_artifact_streaming(agentpool) |
| 1552 | + # set up managed gpu |
| 1553 | + agentpool = self.set_up_managed_gpu(agentpool) |
1513 | 1554 | # set up skip_gpu_driver_install |
1514 | 1555 | agentpool = self.set_up_skip_gpu_driver_install(agentpool) |
1515 | 1556 | # set up gpu profile |
@@ -1704,6 +1745,29 @@ def update_artifact_streaming(self, agentpool: AgentPool) -> AgentPool: |
1704 | 1745 | agentpool.artifact_streaming_profile.enabled = True |
1705 | 1746 | return agentpool |
1706 | 1747 |
|
| 1748 | + def update_managed_gpu(self, agentpool: AgentPool) -> AgentPool: |
| 1749 | + """Update managed GPU property for the AgentPool object. |
| 1750 | + :return: the AgentPool object |
| 1751 | + """ |
| 1752 | + self._ensure_agentpool(agentpool) |
| 1753 | + |
| 1754 | + enable_managed_gpu = self.context.get_enable_managed_gpu() |
| 1755 | + if enable_managed_gpu is None: |
| 1756 | + return agentpool |
| 1757 | + |
| 1758 | + if enable_managed_gpu: |
| 1759 | + if agentpool.gpu_profile is None: |
| 1760 | + agentpool.gpu_profile = self.models.GPUProfile() # pylint: disable=no-member |
| 1761 | + if agentpool.gpu_profile.nvidia is None: |
| 1762 | + agentpool.gpu_profile.nvidia = self.models.NvidiaGPUProfile() # pylint: disable=no-member |
| 1763 | + agentpool.gpu_profile.nvidia.management_mode = CONST_GPU_MANAGEMENT_MODE_MANAGED |
| 1764 | + agentpool.gpu_profile.driver = CONST_GPU_DRIVER_INSTALL |
| 1765 | + else: |
| 1766 | + if agentpool.gpu_profile and agentpool.gpu_profile.nvidia: |
| 1767 | + agentpool.gpu_profile.nvidia.management_mode = CONST_GPU_MANAGEMENT_MODE_UNMANAGED |
| 1768 | + |
| 1769 | + return agentpool |
| 1770 | + |
1707 | 1771 | def update_os_sku(self, agentpool: AgentPool) -> AgentPool: |
1708 | 1772 | self._ensure_agentpool(agentpool) |
1709 | 1773 |
|
@@ -1828,6 +1892,9 @@ def update_agentpool_profile_preview(self, agentpools: List[AgentPool] = None) - |
1828 | 1892 | # update artifact streaming |
1829 | 1893 | agentpool = self.update_artifact_streaming(agentpool) |
1830 | 1894 |
|
| 1895 | + # update managed gpu |
| 1896 | + agentpool = self.update_managed_gpu(agentpool) |
| 1897 | + |
1831 | 1898 | # update secure boot |
1832 | 1899 | agentpool = self.update_secure_boot(agentpool) |
1833 | 1900 |
|
|
0 commit comments