@@ -556,6 +556,23 @@ def get_artifact_mock():
556556 yield get_artifact_mock
557557
558558
559+ @pytest .fixture
560+ def get_artifact_mock_with_metadata ():
561+ with patch .object (MetadataServiceClient , "get_artifact" ) as get_artifact_mock :
562+ get_artifact_mock .return_value = GapicArtifact (
563+ name = _TEST_ARTIFACT_NAME ,
564+ display_name = _TEST_ARTIFACT_ID ,
565+ schema_title = constants .SYSTEM_METRICS ,
566+ schema_version = constants .SCHEMA_VERSIONS [constants .SYSTEM_METRICS ],
567+ metadata = {
568+ google .cloud .aiplatform .metadata .constants ._VERTEX_EXPERIMENT_TRACKING_LABEL : True ,
569+ constants .GCP_ARTIFACT_RESOURCE_NAME_KEY : test_constants .TensorboardConstants ._TEST_TENSORBOARD_RUN_NAME ,
570+ constants ._STATE_KEY : gca_execution .Execution .State .RUNNING ,
571+ },
572+ )
573+ yield get_artifact_mock
574+
575+
559576@pytest .fixture
560577def get_artifact_not_found_mock ():
561578 with patch .object (MetadataServiceClient , "get_artifact" ) as get_artifact_mock :
@@ -2026,6 +2043,27 @@ def test_experiment_run_get_logged_custom_jobs(self, get_custom_job_mock):
20262043 retry = base ._DEFAULT_RETRY ,
20272044 )
20282045
2046+ @pytest .mark .usefixtures (
2047+ "get_metadata_store_mock" ,
2048+ "get_experiment_mock" ,
2049+ "get_experiment_run_mock" ,
2050+ "get_context_mock" ,
2051+ "list_contexts_mock" ,
2052+ "list_executions_mock" ,
2053+ "get_artifact_mock_with_metadata" ,
2054+ "update_context_mock" ,
2055+ )
2056+ def test_update_experiment_run_after_list (
2057+ self ,
2058+ ):
2059+ aiplatform .init (
2060+ project = _TEST_PROJECT ,
2061+ location = _TEST_LOCATION ,
2062+ )
2063+
2064+ experiment_run_list = aiplatform .ExperimentRun .list (experiment = _TEST_EXPERIMENT )
2065+ experiment_run_list [0 ].update_state (gca_execution .Execution .State .FAILED )
2066+
20292067
20302068class TestTensorboard :
20312069 def test_get_or_create_default_tb_with_existing_default (
0 commit comments