Skip to content

Commit ed65494

Browse files
bvliucopybara-github
authored andcommitted
Refactor RelationalDbSpec registration to use CLOUD and ENGINE.
PiperOrigin-RevId: 904720629
1 parent e1f76e5 commit ed65494

5 files changed

Lines changed: 35 additions & 15 deletions

File tree

perfkitbenchmarker/configs/benchmark_config_spec.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -811,14 +811,18 @@ def Decode(self, value, component_full_name, flag_values):
811811
# LoadProvider is required for resources to be registered.
812812
_LoadProvider(relational_db_config, flag_values)
813813

814+
cloud = relational_db_config.get('cloud')
815+
if not cloud or flag_values['cloud'].present:
816+
cloud = flag_values.cloud
817+
814818
if 'engine' in relational_db_config:
815819
if flag_values['db_engine'].present:
816820
db_spec_class = relational_db_spec.GetRelationalDbSpecClass(
817-
flag_values['db_engine'].value
821+
cloud, flag_values['db_engine'].value
818822
)
819823
else:
820824
db_spec_class = relational_db_spec.GetRelationalDbSpecClass(
821-
relational_db_config['engine']
825+
cloud, relational_db_config['engine']
822826
)
823827
else:
824828
raise errors.Config.InvalidValue(

perfkitbenchmarker/providers/aws/aws_aurora_dsql_db.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@
5252
class AwsAuroraDsqlSpec(relational_db_spec.RelationalDbSpec):
5353
"""Configurable options for AWS Aurora DSQL."""
5454

55-
SERVICE_TYPE = 'aurora-dsql'
55+
CLOUD = 'AWS'
56+
ENGINE = [sql_engine_utils.AURORA_DSQL_POSTGRES]
5657

5758
@classmethod
5859
def _GetOptionDecoderConstructions(cls):

perfkitbenchmarker/providers/gcp/gcp_spanner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,11 @@ class SpannerAsyncOperationInProgressError(Exception):
206206
class SpannerSpec(relational_db_spec.RelationalDbSpec):
207207
"""Configurable options of a Spanner instance."""
208208

209-
SERVICE_TYPE = 'spanner'
209+
CLOUD = 'GCP'
210+
ENGINE = [
211+
sql_engine_utils.SPANNER_GOOGLESQL,
212+
sql_engine_utils.SPANNER_POSTGRES,
213+
]
210214

211215
spanner_instance_id: str
212216
spanner_database_id: str

perfkitbenchmarker/relational_db_spec.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,9 @@
3333
_NONE_OK = {'default': None, 'none_ok': True}
3434

3535

36-
def GetRelationalDbSpecClass(engine):
37-
"""Get the RelationalDbSpec class corresponding to 'engine'."""
38-
if engine in [
39-
sql_engine_utils.SPANNER_GOOGLESQL,
40-
sql_engine_utils.SPANNER_POSTGRES,
41-
]:
42-
return spec.GetSpecClass(RelationalDbSpec, SERVICE_TYPE='spanner')
43-
if engine == sql_engine_utils.AURORA_DSQL_POSTGRES:
44-
return spec.GetSpecClass(RelationalDbSpec, SERVICE_TYPE='aurora-dsql')
45-
return RelationalDbSpec
36+
def GetRelationalDbSpecClass(cloud, engine):
37+
"""Get the RelationalDbSpec class corresponding to 'cloud' and 'engine'."""
38+
return spec.GetSpecClass(RelationalDbSpec, CLOUD=cloud, ENGINE=engine)
4639

4740

4841
class RelationalDbSpec(freeze_restore_spec.FreezeRestoreSpec):
@@ -63,7 +56,7 @@ class RelationalDbSpec(freeze_restore_spec.FreezeRestoreSpec):
6356
"""
6457

6558
SPEC_TYPE = 'RelationalDbSpec'
66-
SPEC_ATTRS = ['SERVICE_TYPE']
59+
SPEC_ATTRS = ['CLOUD', 'ENGINE']
6760

6861
cloud: str
6962
engine: str

tests/relational_db_spec_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
import unittest
1717
from absl import flags
18+
from absl.testing import parameterized
1819
from perfkitbenchmarker import errors
20+
from perfkitbenchmarker import providers
1921
from perfkitbenchmarker import relational_db
2022
from perfkitbenchmarker import relational_db_spec
2123
from perfkitbenchmarker.providers.gcp import gce_virtual_machine
@@ -318,6 +320,22 @@ def testReverseOrderedDbReplicaZonesFlag(self):
318320
result.vm_groups['servers_replicas'].vm_spec.zone, 'us-central1-d'
319321
)
320322

323+
# pyformat: disable
324+
@parameterized.named_parameters([
325+
('SpannerGoogleSql', 'GCP', 'spanner-googlesql', 'SpannerSpec'),
326+
('SpannerPostgres', 'GCP', 'spanner-postgres', 'SpannerSpec'),
327+
('AuroraDsql', 'AWS', 'aurora-dsql-postgres', 'AwsAuroraDsqlSpec'),
328+
('RdsMysql', 'AWS', 'mysql', 'RelationalDbSpec'),
329+
('CloudSqlPostgres', 'GCP', 'postgres', 'RelationalDbSpec'),
330+
('AzureFlexibleServer', 'Azure', 'flexible-server-postgres', 'RelationalDbSpec'),
331+
('AzureSqlManagedInstance', 'Azure', 'sqlserver', 'RelationalDbSpec'),
332+
])
333+
# pyformat: enable
334+
def testGetRelationalDbSpecClass(self, cloud, engine, expected_class_name):
335+
providers.LoadProvider(cloud, True)
336+
actual_class = relational_db_spec.GetRelationalDbSpecClass(cloud, engine)
337+
self.assertEqual(actual_class.__name__, expected_class_name)
338+
321339

322340
if __name__ == '__main__':
323341
unittest.main()

0 commit comments

Comments
 (0)