3434
3535
3636# -*- coding: utf-8 -*-
37- # TODO(b/328684671)
3837_EXPECTED_MASK = field_mask_pb2 .FieldMask (paths = ["resource_pools.replica_count" ])
3938
4039# for manual scaling
@@ -241,6 +240,22 @@ def update_persistent_resource_2_pools_mock():
241240 yield update_persistent_resource_2_pools_mock
242241
243242
243+ def cluster_eq (returned_cluster , expected_cluster ):
244+ assert vars (returned_cluster .head_node_type ) == vars (
245+ expected_cluster .head_node_type
246+ )
247+ assert vars (returned_cluster .worker_node_types [0 ]) == vars (
248+ expected_cluster .worker_node_types [0 ]
249+ )
250+ assert (
251+ returned_cluster .cluster_resource_name == expected_cluster .cluster_resource_name
252+ )
253+ assert returned_cluster .python_version == expected_cluster .python_version
254+ assert returned_cluster .ray_version == expected_cluster .ray_version
255+ assert returned_cluster .network == expected_cluster .network
256+ assert returned_cluster .state == expected_cluster .state
257+
258+
244259@pytest .mark .usefixtures ("google_auth_mock" , "get_project_number_mock" )
245260class TestClusterManagement :
246261 def setup_method (self ):
@@ -315,6 +330,7 @@ def test_create_ray_cluster_1_pool_gpu_with_labels_success(
315330 network = tc .ProjectConstants .TEST_VPC_NETWORK ,
316331 cluster_name = tc .ClusterConstants .TEST_VERTEX_RAY_PR_ID ,
317332 labels = tc .ClusterConstants .TEST_LABELS ,
333+ enable_metrics_collection = False ,
318334 )
319335
320336 assert tc .ClusterConstants .TEST_VERTEX_RAY_PR_ADDRESS == cluster_name
@@ -465,21 +481,7 @@ def test_get_ray_cluster_success(self, get_persistent_resource_1_pool_mock):
465481 )
466482
467483 get_persistent_resource_1_pool_mock .assert_called_once ()
468-
469- assert vars (cluster .head_node_type ) == vars (
470- tc .ClusterConstants .TEST_CLUSTER .head_node_type
471- )
472- assert vars (cluster .worker_node_types [0 ]) == vars (
473- tc .ClusterConstants .TEST_CLUSTER .worker_node_types [0 ]
474- )
475- assert (
476- cluster .cluster_resource_name
477- == tc .ClusterConstants .TEST_CLUSTER .cluster_resource_name
478- )
479- assert cluster .python_version == tc .ClusterConstants .TEST_CLUSTER .python_version
480- assert cluster .ray_version == tc .ClusterConstants .TEST_CLUSTER .ray_version
481- assert cluster .network == tc .ClusterConstants .TEST_CLUSTER .network
482- assert cluster .state == tc .ClusterConstants .TEST_CLUSTER .state
484+ cluster_eq (cluster , tc .ClusterConstants .TEST_CLUSTER )
483485
484486 def test_get_ray_cluster_with_custom_image_success (
485487 self , get_persistent_resource_2_pools_custom_image_mock
@@ -489,27 +491,7 @@ def test_get_ray_cluster_with_custom_image_success(
489491 )
490492
491493 get_persistent_resource_2_pools_custom_image_mock .assert_called_once ()
492-
493- assert vars (cluster .head_node_type ) == vars (
494- tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .head_node_type
495- )
496- assert vars (cluster .worker_node_types [0 ]) == vars (
497- tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .worker_node_types [0 ]
498- )
499- assert (
500- cluster .cluster_resource_name
501- == tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .cluster_resource_name
502- )
503- assert (
504- cluster .python_version
505- == tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .python_version
506- )
507- assert (
508- cluster .ray_version
509- == tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .ray_version
510- )
511- assert cluster .network == tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .network
512- assert cluster .state == tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .state
494+ cluster_eq (cluster , tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE )
513495
514496 @pytest .mark .usefixtures ("get_persistent_resource_exception_mock" )
515497 def test_get_ray_cluster_error (self ):
@@ -526,42 +508,9 @@ def test_list_ray_clusters_success(self, list_persistent_resources_mock):
526508 list_persistent_resources_mock .assert_called_once ()
527509
528510 # first ray cluster
529- assert vars (clusters [0 ].head_node_type ) == vars (
530- tc .ClusterConstants .TEST_CLUSTER .head_node_type
531- )
532- assert vars (clusters [0 ].worker_node_types [0 ]) == vars (
533- tc .ClusterConstants .TEST_CLUSTER .worker_node_types [0 ]
534- )
535- assert (
536- clusters [0 ].cluster_resource_name
537- == tc .ClusterConstants .TEST_CLUSTER .cluster_resource_name
538- )
539- assert (
540- clusters [0 ].python_version
541- == tc .ClusterConstants .TEST_CLUSTER .python_version
542- )
543- assert clusters [0 ].ray_version == tc .ClusterConstants .TEST_CLUSTER .ray_version
544- assert clusters [0 ].network == tc .ClusterConstants .TEST_CLUSTER .network
545- assert clusters [0 ].state == tc .ClusterConstants .TEST_CLUSTER .state
546-
511+ cluster_eq (clusters [0 ], tc .ClusterConstants .TEST_CLUSTER )
547512 # second ray cluster
548- assert vars (clusters [1 ].head_node_type ) == vars (
549- tc .ClusterConstants .TEST_CLUSTER_2 .head_node_type
550- )
551- assert vars (clusters [1 ].worker_node_types [0 ]) == vars (
552- tc .ClusterConstants .TEST_CLUSTER_2 .worker_node_types [0 ]
553- )
554- assert (
555- clusters [1 ].cluster_resource_name
556- == tc .ClusterConstants .TEST_CLUSTER_2 .cluster_resource_name
557- )
558- assert (
559- clusters [1 ].python_version
560- == tc .ClusterConstants .TEST_CLUSTER_2 .python_version
561- )
562- assert clusters [1 ].ray_version == tc .ClusterConstants .TEST_CLUSTER_2 .ray_version
563- assert clusters [1 ].network == tc .ClusterConstants .TEST_CLUSTER_2 .network
564- assert clusters [1 ].state == tc .ClusterConstants .TEST_CLUSTER_2 .state
513+ cluster_eq (clusters [1 ], tc .ClusterConstants .TEST_CLUSTER_2 )
565514
566515 def test_list_ray_clusters_initialized_success (
567516 self , get_project_number_mock , list_persistent_resources_mock
0 commit comments